lightly-train 0.2.3__tar.gz → 0.2.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143) hide show
  1. {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/build_docker_image.yml +2 -2
  2. {lightly_train-0.2.3 → lightly_train-0.2.4}/DOCS.md +26 -5
  3. {lightly_train-0.2.3 → lightly_train-0.2.4}/Makefile +3 -3
  4. {lightly_train-0.2.3/src/lightly_train.egg-info → lightly_train-0.2.4}/PKG-INFO +3 -1
  5. {lightly_train-0.2.3 → lightly_train-0.2.4}/docker/Makefile +2 -0
  6. {lightly_train-0.2.3 → lightly_train-0.2.4}/pyproject.toml +3 -0
  7. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/__init__.py +1 -1
  8. lightly_train-0.2.4/src/lightly_train/_callbacks/callback_args.py +45 -0
  9. lightly_train-0.2.4/src/lightly_train/_callbacks/callback_helpers.py +75 -0
  10. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_callbacks/checkpoint.py +11 -1
  11. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_cli.py +18 -4
  12. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/export.py +2 -2
  13. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/train.py +33 -1
  14. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/train_helpers.py +5 -46
  15. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_loggers/jsonl.py +12 -3
  16. lightly_train-0.2.4/src/lightly_train/_loggers/logger_args.py +23 -0
  17. lightly_train-0.2.4/src/lightly_train/_loggers/logger_helpers.py +89 -0
  18. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_loggers/tensorboard.py +14 -0
  19. lightly_train-0.2.4/src/lightly_train/_loggers/wandb.py +32 -0
  20. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/densecl.py +6 -6
  21. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/densecldino.py +5 -6
  22. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/dino.py +6 -6
  23. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/method.py +48 -1
  24. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/simclr.py +7 -8
  25. lightly_train-0.2.4/src/lightly_train/_plot.py +95 -0
  26. {lightly_train-0.2.3 → lightly_train-0.2.4/src/lightly_train.egg-info}/PKG-INFO +3 -1
  27. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train.egg-info/SOURCES.txt +10 -0
  28. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train.egg-info/requires.txt +3 -0
  29. lightly_train-0.2.4/tests/_callbacks/test_callback_helpers.py +154 -0
  30. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_commands/test_common_helpers.py +3 -0
  31. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_commands/test_embed.py +28 -5
  32. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_commands/test_train.py +52 -3
  33. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_commands/test_train_helpers.py +5 -4
  34. lightly_train-0.2.4/tests/_loggers/test_logger_helpers.py +158 -0
  35. lightly_train-0.2.4/tests/_optim/__init__.py +7 -0
  36. lightly_train-0.2.4/tests/test__plot.py +31 -0
  37. {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/pull_request_template.md +0 -0
  38. {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/release_dockerhub.yml +0 -0
  39. {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/release_pypi.yml +0 -0
  40. {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/test.yml +0 -0
  41. {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/test_code_format.yml +0 -0
  42. {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/test_docker.yml +0 -0
  43. {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/test_minimal_deps.yml +0 -0
  44. {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/weekly_dependency_tests.yml +0 -0
  45. {lightly_train-0.2.3 → lightly_train-0.2.4}/.gitignore +0 -0
  46. {lightly_train-0.2.3 → lightly_train-0.2.4}/CONTRIBUTE.md +0 -0
  47. {lightly_train-0.2.3 → lightly_train-0.2.4}/LICENSE +0 -0
  48. {lightly_train-0.2.3 → lightly_train-0.2.4}/README.md +0 -0
  49. {lightly_train-0.2.3 → lightly_train-0.2.4}/dev_tools/licenseheader.tmpl +0 -0
  50. {lightly_train-0.2.3 → lightly_train-0.2.4}/docker/Dockerfile-amd64-cuda +0 -0
  51. {lightly_train-0.2.3 → lightly_train-0.2.4}/docker/README.md +0 -0
  52. {lightly_train-0.2.3 → lightly_train-0.2.4}/setup.cfg +0 -0
  53. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_callbacks/__init__.py +0 -0
  54. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_checkpoint.py +0 -0
  55. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/__init__.py +0 -0
  56. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/_warnings.py +0 -0
  57. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/common_helpers.py +0 -0
  58. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/embed.py +0 -0
  59. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/extract_video_frames.py +0 -0
  60. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_configs/__init__.py +0 -0
  61. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_configs/config.py +0 -0
  62. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_configs/omegaconf_utils.py +0 -0
  63. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_configs/validate.py +0 -0
  64. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_constants.py +0 -0
  65. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/__init__.py +0 -0
  66. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/embedding_format.py +0 -0
  67. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/embedding_predictor.py +0 -0
  68. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/embedding_transform.py +0 -0
  69. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/writers/__init__.py +0 -0
  70. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/writers/csv_writer.py +0 -0
  71. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/writers/embedding_writer.py +0 -0
  72. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/writers/torch_writer.py +0 -0
  73. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/writers/writer_helpers.py +0 -0
  74. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_loggers/__init__.py +0 -0
  75. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_logging.py +0 -0
  76. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/__init__.py +0 -0
  77. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/method_args.py +0 -0
  78. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/method_helpers.py +0 -0
  79. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/__init__.py +0 -0
  80. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/custom/__init__.py +0 -0
  81. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/custom/custom.py +0 -0
  82. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/custom/custom_package.py +0 -0
  83. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/embedding/__init__.py +0 -0
  84. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/embedding/base.py +0 -0
  85. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/embedding_model.py +0 -0
  86. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/feature_extractor.py +0 -0
  87. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/package.py +0 -0
  88. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/package_helpers.py +0 -0
  89. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/super_gradients/__init__.py +0 -0
  90. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/super_gradients/customizable_detector.py +0 -0
  91. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/super_gradients/super_gradients.py +0 -0
  92. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/super_gradients/super_gradients_package.py +0 -0
  93. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/timm/__init__.py +0 -0
  94. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/timm/timm.py +0 -0
  95. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/timm/timm_package.py +0 -0
  96. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/torchvision/__init__.py +0 -0
  97. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/torchvision/convnext.py +0 -0
  98. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/torchvision/resnet.py +0 -0
  99. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/torchvision/torchvision.py +0 -0
  100. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/torchvision/torchvision_package.py +0 -0
  101. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_optim/__init__.py +0 -0
  102. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_optim/adamw_args.py +0 -0
  103. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_optim/optimizer.py +0 -0
  104. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_optim/optimizer_args.py +0 -0
  105. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_optim/optimizer_type.py +0 -0
  106. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_optim/trainable_modules.py +0 -0
  107. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_scaling.py +0 -0
  108. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_transforms.py +0 -0
  109. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/errors.py +0 -0
  110. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/types.py +0 -0
  111. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train.egg-info/dependency_links.txt +0 -0
  112. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train.egg-info/entry_points.txt +0 -0
  113. {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train.egg-info/top_level.txt +0 -0
  114. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/__init__.py +0 -0
  115. {lightly_train-0.2.3/tests/_commands → lightly_train-0.2.4/tests/_callbacks}/__init__.py +0 -0
  116. {lightly_train-0.2.3/tests/_configs → lightly_train-0.2.4/tests/_commands}/__init__.py +0 -0
  117. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_commands/test_export.py +0 -0
  118. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_commands/test_extract_video_frames.py +0 -0
  119. {lightly_train-0.2.3/tests/_embedding → lightly_train-0.2.4/tests/_configs}/__init__.py +0 -0
  120. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_configs/test_validate.py +0 -0
  121. {lightly_train-0.2.3/tests/_embedding/writers → lightly_train-0.2.4/tests/_embedding}/__init__.py +0 -0
  122. {lightly_train-0.2.3/tests/_methods → lightly_train-0.2.4/tests/_embedding/writers}/__init__.py +0 -0
  123. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_embedding/writers/test_csv_writer.py +0 -0
  124. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_embedding/writers/test_torch_writer.py +0 -0
  125. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_embedding/writers/test_writer_helpers.py +0 -0
  126. {lightly_train-0.2.3/tests/_models → lightly_train-0.2.4/tests/_methods}/__init__.py +0 -0
  127. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_methods/test_method_helpers.py +0 -0
  128. {lightly_train-0.2.3/tests/_models/custom → lightly_train-0.2.4/tests/_models}/__init__.py +0 -0
  129. {lightly_train-0.2.3/tests/_models/timm → lightly_train-0.2.4/tests/_models/custom}/__init__.py +0 -0
  130. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_models/custom/test_custom.py +0 -0
  131. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_models/custom/test_custom_package.py +0 -0
  132. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_models/super_gradients/test_super_gradients_package.py +0 -0
  133. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_models/test_package_helpers.py +0 -0
  134. {lightly_train-0.2.3/tests/_optim → lightly_train-0.2.4/tests/_models/timm}/__init__.py +0 -0
  135. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_models/timm/test_timm.py +0 -0
  136. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_models/timm/test_timm_package.py +0 -0
  137. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_models/torchvision/test_torchvision_package.py +0 -0
  138. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_optim/test_optimizer.py +0 -0
  139. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/helpers.py +0 -0
  140. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/test__checkpoint.py +0 -0
  141. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/test__cli.py +0 -0
  142. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/test__logging.py +0 -0
  143. {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/test__scaling.py +0 -0
@@ -52,7 +52,7 @@ jobs:
52
52
  if: ${{ needs.build-docker-image.result == 'success' }}
53
53
  uses: rtCamp/action-slack-notify@v2
54
54
  env:
55
- SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_RELEASES }}
55
+ SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_RELEASES_DEV }}
56
56
  SLACK_ICON_EMOJI: ":train:"
57
57
  SLACK_USERNAME: "Build of LightlyTrain Docker Image"
58
58
  SLACK_COLOR: "good"
@@ -71,7 +71,7 @@ jobs:
71
71
  if: ${{ needs.build-docker-image.result != 'success' }}
72
72
  uses: rtCamp/action-slack-notify@v2
73
73
  env:
74
- SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_RELEASES }}
74
+ SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_RELEASES_DEV }}
75
75
  SLACK_ICON_EMOJI: ":x:"
76
76
  SLACK_USERNAME: "Build of LightlyTrain Docker Image"
77
77
  SLACK_COLOR: "danger"
@@ -27,12 +27,20 @@ lightly_train.train(
27
27
 
28
28
  In most cases you only have to specify `out`, `data`, and `model`. The rest is optional.
29
29
 
30
- The training process can be monitored with TensorBoard (requires
31
- `pip install lightly-train[tensorboard]`):
30
+ You can monitor your training process with the help of `tensorboard` and `wandb` loggers:
32
31
  ```
33
- tensorboard --logdir my_output_dir
32
+ pip install "lightly-train[tensorboard, wandb]"
34
33
  ```
35
34
 
35
+ Configure the loggers from Python:
36
+ * `loggers={"tensorboard": True}`: Enable `tensorboard` logger with default arguments.
37
+ * `loggers={"wandb": True}`: Enable `wandb` logger with default arguments.
38
+ * `loggers={"wandb": {"project": "my-project"}}`: Configure `wandb` logger with custom arguments.
39
+
40
+ LightlyTrain uses the PyTorch Lightning loggers under the hood. Learn more about their configuration:
41
+ * https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.wandb.html
42
+ * https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.tensorboard.html
43
+
36
44
  #### Exporting
37
45
  ```python
38
46
  import lightly_train
@@ -94,8 +102,21 @@ lightly-train train \
94
102
 
95
103
  In most cases you only have to specify `out`, `data`, and `model`. The rest is optional.
96
104
 
97
- The training process can be monitored with TensorBoard (requires
98
- `pip install lightly-train[tensorboard]`):
105
+ You can monitor your training process with the help of `tensorboard` and `wandb` loggers:
106
+ ```
107
+ pip install "lightly-train[tensorboard, wandb]"
108
+ ```
109
+
110
+ Configure the loggers from the command-line:
111
+ * `loggers.tensorboard=True`: Enable `tensorboard` logger with default arguments.
112
+ * `loggers.wandb=True`: Enable `wandb` logger with default arguments.
113
+ * `loggers.wandb.project="my-project"`: Configure `wandb` logger with custom arguments.
114
+
115
+ LightlyTrain uses the PyTorch Lightning loggers under the hood. Learn more about their configuration:
116
+ * https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.wandb.html
117
+ * https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.tensorboard.html
118
+
119
+ View the TensorBoard logs with:
99
120
  ```
100
121
  tensorboard --logdir my_output_dir
101
122
  ```
@@ -72,7 +72,7 @@ test:
72
72
 
73
73
  .PHONY: test-ci
74
74
  test-ci:
75
- pytest tests --capture=no -v
75
+ pytest tests -v
76
76
 
77
77
 
78
78
  ### Virtual Environment
@@ -105,8 +105,8 @@ endif
105
105
 
106
106
  # SuperGradients is not compatible with Python>=3.10. It is also not easy to install
107
107
  # on MacOS. Therefore we exclude it from the default extras.
108
- EXTRAS = [dev,tensorboard,timm]
109
- DOCKER_EXTRAS = --extra tensorboard --extra timm
108
+ EXTRAS = [dev,tensorboard,timm,wandb]
109
+ DOCKER_EXTRAS = --extra tensorboard --extra timm --extra wandb
110
110
 
111
111
  # Date until which dependencies installed with --exclude-newer must have been released.
112
112
  # Dependencies released after this date are ignored.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: lightly_train
3
- Version: 0.2.3
3
+ Version: 0.2.4
4
4
  Summary: Train models with self-supervised learning in a single command
5
5
  Author: Lightly Team
6
6
  License: AGPL-3.0
@@ -28,6 +28,8 @@ Provides-Extra: tensorboard
28
28
  Requires-Dist: tensorboard>=2.10.0; extra == "tensorboard"
29
29
  Provides-Extra: timm
30
30
  Requires-Dist: timm>=1.0.3; extra == "timm"
31
+ Provides-Extra: wandb
32
+ Requires-Dist: wandb>=0.12.10; extra == "wandb"
31
33
 
32
34
  # LightlyTrain
33
35
 
@@ -69,6 +69,7 @@ test:
69
69
  -v $(LIGHTLY_TRAIN_OUT):/out \
70
70
  -v $(LIGHTLY_TRAIN_DATA):/data \
71
71
  -v ./Makefile:/home/lightly_train/docker/Makefile \
72
+ --gpus all \
72
73
  lightly/$(IMAGE):$(TAG) make test-cli-from-within-docker -C docker
73
74
 
74
75
  # This target is run from within the docker container
@@ -76,6 +77,7 @@ test-cli-from-within-docker:
76
77
  @echo "Test train"
77
78
  lightly-train train data=/data out=/out model="torchvision/convnext_small" epochs=2 batch_size=2 model_args.weights="IMAGENET1K_V1" devices=2
78
79
  test -f /out/checkpoints/last.ckpt
80
+ test `grep -c "GPU available: True (cuda), used: True" /out/train.log` -gt 0
79
81
  @echo "Test embed"
80
82
  lightly-train embed data=/data out="/out/embeddings.csv" checkpoint="/out/checkpoints/last.ckpt" batch_size=2 format="csv"
81
83
  test `wc -l < /out/embeddings.csv` -eq 6
@@ -52,6 +52,9 @@ tensorboard = [
52
52
  timm = [
53
53
  "timm>=1.0.3",
54
54
  ]
55
+ wandb = [
56
+ "wandb>=0.12.10", # required by pytorch-lightning
57
+ ]
55
58
 
56
59
  [project.scripts]
57
60
  lightly-train = "lightly_train._cli:_cli_entrypoint"
@@ -31,4 +31,4 @@ __all__ = [
31
31
  "train",
32
32
  ]
33
33
 
34
- __version__ = "0.2.3"
34
+ __version__ = "0.2.4"
@@ -0,0 +1,45 @@
1
+ #
2
+ # Copyright (c) Lightly AG and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+ from dataclasses import dataclass, field
9
+ from typing import Optional
10
+
11
+ from lightly_train._callbacks.checkpoint import ModelCheckpointArgs
12
+ from lightly_train._configs.config import Config
13
+
14
+
15
@dataclass
class LearningRateMonitorArgs(Config):
    """Arguments for the learning rate monitor callback (no options exposed yet)."""


@dataclass
class DeviceStatsMonitorArgs(Config):
    """Arguments for the device stats monitor callback (no options exposed yet)."""


@dataclass
class EarlyStoppingArgs(Config):
    """Arguments for the early stopping callback.

    The default ``patience`` is so large (10**12 checks) that training is
    effectively never stopped for a lack of improvement. The callback then
    mainly acts through ``check_finite``, which in Lightning's EarlyStopping
    stops training once the monitored value becomes non-finite.
    """

    monitor: str = "train_loss"
    patience: int = 10**12
    check_finite: bool = True


@dataclass
class CallbackArgs(Config):
    """Top-level configuration of the callbacks attached to the trainer.

    Each field defaults to a fresh instance of its argument dataclass, so every
    callback is enabled by default. Setting a field to ``None`` disables the
    corresponding callback.
    """

    learning_rate_monitor: Optional[LearningRateMonitorArgs] = field(
        default_factory=LearningRateMonitorArgs
    )
    device_stats_monitor: Optional[DeviceStatsMonitorArgs] = field(
        default_factory=DeviceStatsMonitorArgs
    )
    early_stopping: Optional[EarlyStoppingArgs] = field(
        default_factory=EarlyStoppingArgs
    )
    model_checkpoint: Optional[ModelCheckpointArgs] = field(
        default_factory=ModelCheckpointArgs
    )
@@ -0,0 +1,75 @@
1
+ #
2
+ # Copyright (c) Lightly AG and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+ from __future__ import annotations
9
+
10
+ import dataclasses
11
+ from pathlib import Path
12
+ from typing import Any
13
+
14
+ from omegaconf import OmegaConf
15
+ from pytorch_lightning import Callback
16
+ from pytorch_lightning.callbacks import (
17
+ DeviceStatsMonitor,
18
+ EarlyStopping,
19
+ LearningRateMonitor,
20
+ )
21
+ from torch.nn import Module
22
+
23
+ from lightly_train._callbacks.callback_args import (
24
+ CallbackArgs,
25
+ )
26
+ from lightly_train._callbacks.checkpoint import ModelCheckpoint
27
+ from lightly_train._checkpoint import CheckpointLightlyTrainModels
28
+ from lightly_train._configs import validate
29
+ from lightly_train._models.embedding_model import EmbeddingModel
30
+
31
+
32
def validate_callback_args(callback_dict: dict[str, Any] | None) -> CallbackArgs:
    """Validate a user-supplied callback configuration and build ``CallbackArgs``.

    Args:
        callback_dict:
            Dictionary mapping callback names (fields of ``CallbackArgs``) to
            their arguments. ``None`` is treated like an empty dictionary, so
            all defaults of ``CallbackArgs`` apply.

    Returns:
        The validated ``CallbackArgs`` instance.
    """
    raw_config = callback_dict if callback_dict is not None else {}
    checked_config = validate.validate_dictconfig(
        OmegaConf.create(raw_config), CallbackArgs
    )
    parsed = OmegaConf.to_object(checked_config)
    # to_object is typed loosely; narrow the type for mypy.
    assert isinstance(parsed, CallbackArgs)
    return parsed
42
+
43
+
44
def get_callbacks(
    callback_args: CallbackArgs,
    out: Path,
    model: Module,
    embedding_model: EmbeddingModel,
) -> list[Callback]:
    """Instantiate the trainer callbacks enabled in ``callback_args``.

    A callback is created only when its entry in ``callback_args`` is not
    ``None``. The fields of that entry are forwarded as keyword arguments to
    the callback constructor.

    Args:
        callback_args: Validated callback configuration.
        out: Output directory of the training run. The checkpoint callback
            writes to ``out / "checkpoints"``.
        model: The model being trained. It is handed to the checkpoint callback
            through ``CheckpointLightlyTrainModels``.
        embedding_model: The embedding model. It is handed to the checkpoint
            callback through ``CheckpointLightlyTrainModels``.

    Returns:
        The callbacks to pass to the trainer. They appear in this order:
        learning rate monitor, device stats monitor, early stopping, model
        checkpoint. Disabled callbacks are skipped.
    """
    callbacks: list[Callback] = []
    if callback_args.learning_rate_monitor is not None:
        callbacks.append(
            LearningRateMonitor(
                **dataclasses.asdict(callback_args.learning_rate_monitor)
            )
        )
    if callback_args.device_stats_monitor is not None:
        callbacks.append(
            DeviceStatsMonitor(**dataclasses.asdict(callback_args.device_stats_monitor))
        )
    if callback_args.early_stopping is not None:
        callbacks.append(
            EarlyStopping(**dataclasses.asdict(callback_args.early_stopping))
        )
    if callback_args.model_checkpoint is not None:
        callbacks.append(
            ModelCheckpoint(
                models=CheckpointLightlyTrainModels(
                    model=model, embedding_model=embedding_model
                ),
                # The checkpoint location is fixed relative to the run output;
                # it is not part of ModelCheckpointArgs.
                dirpath=out / "checkpoints",
                **dataclasses.asdict(callback_args.model_checkpoint),
            )
        )
    return callbacks
@@ -8,8 +8,9 @@
8
8
  from __future__ import annotations
9
9
 
10
10
  import logging
11
+ from dataclasses import dataclass
11
12
  from datetime import timedelta
12
- from typing import Any
13
+ from typing import Any, Optional
13
14
 
14
15
  from pytorch_lightning import LightningModule, Trainer
15
16
  from pytorch_lightning.callbacks import ModelCheckpoint as _ModelCheckpoint
@@ -19,11 +20,20 @@ from lightly_train._checkpoint import (
19
20
  CheckpointLightlyTrain,
20
21
  CheckpointLightlyTrainModels,
21
22
  )
23
+ from lightly_train._configs.config import Config
22
24
  from lightly_train.types import PathLike
23
25
 
24
26
  logger = logging.getLogger(__name__)
25
27
 
26
28
 
29
@dataclass
class ModelCheckpointArgs(Config):
    """User-configurable arguments of the checkpoint callback.

    The fields are forwarded as keyword arguments to ``ModelCheckpoint``, a
    subclass of Lightning's ``ModelCheckpoint``. See the Lightning
    documentation for the meaning of each option. Field order is part of the
    dataclass constructor signature; keep it stable.
    """

    save_last: bool = True
    enable_version_counter: bool = False
    save_top_k: int = 1
    every_n_epochs: Optional[int] = None
35
+
36
+
27
37
  class ModelCheckpoint(_ModelCheckpoint):
28
38
  def __init__(
29
39
  self,
@@ -34,7 +34,7 @@ _HELP_MSG = """
34
34
  lightly-train extract_video_frames Extract frames from videos using ffmpeg.
35
35
  lightly-train help Show help message.
36
36
 
37
- Run `lightly_train <command> help` for more information on a specific command.
37
+ Run `lightly-train <command> help` for more information on a specific command.
38
38
 
39
39
  Optional arguments:
40
40
  -v, --verbose Run the command in verbose mode for detailed output.
@@ -105,6 +105,20 @@ _TRAIN_HELP_MSG = f"""
105
105
  For details, see: https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision
106
106
  seed (int):
107
107
  Random seed for reproducibility. Default: {_train_cfg.seed}
108
+ loggers (dict):
109
+ Loggers for training. Either null or a dictionary of logger names to either
110
+ null or a dictionary of logger arguments. Null uses the default loggers.
111
+ To disable a logger, set it to null: `loggers.tensorboard=null`.
112
+ To configure a logger, pass the respective arguments:
113
+ `loggers.wandb.project="my_project"`.
114
+ Default: null
115
+ callbacks (dict):
116
+ Callbacks for training. Either null or a dictionary of callback names to
117
+ either null or a dictionary of callback arguments. Null uses the default
118
+ callbacks. To disable a callback, set it to null:
119
+ `callbacks.model_checkpoint=null`. To configure a callback, pass the
120
+ respective arguments: `callbacks.model_checkpoint.every_n_epochs=5`.
121
+ Default: null
108
122
  optim_args (dict):
109
123
  Arguments for AdamW optimizer. Available arguments are:
110
124
  - lr (float)
@@ -159,10 +173,10 @@ _EXPORT_HELP_MSG = """
159
173
  Format to save the model in. Valid options are 'torch_model' and
160
174
  'torch_state_dict'. 'torch_model' saves the model as a torch module which
161
175
  can be loaded with `model = torch.load(out)`. This requires that the same
162
- lightly_train version is installed when the model is exported and when it is
176
+ LightlyTrain version is installed when the model is exported and when it is
163
177
  loaded again. 'torch_state_dict' saves the model's state dict which can be
164
178
  loaded with `model.load_state_dict(torch.load(out))`. This is more flexible
165
- and can be used to load the model with different lightly_train versions but
179
+ and can be used to load the model with different LightlyTrain versions but
166
180
  requires the model to already be instantiated.
167
181
 
168
182
  Optional arguments:
@@ -399,7 +413,7 @@ def _show_invalid_command_help(command: str) -> None:
399
413
  msg = _format_msg(
400
414
  f"""
401
415
  Unknown command '{command}':
402
- lightly_train {command}
416
+ lightly-train {command}
403
417
  """
404
418
  )
405
419
  msg += "\n"
@@ -52,10 +52,10 @@ def export(
52
52
  Format to save the model in. Valid options are 'torch_model' and
53
53
  'torch_state_dict'. 'torch_model' saves the model as a torch module which
54
54
  can be loaded with `model = torch.load(out)`. This requires that the same
55
- lightly_train version is installed when the model is exported and when it is
55
+ LightlyTrain version is installed when the model is exported and when it is
56
56
  loaded again. 'torch_state_dict' saves the model's state dict which can be
57
57
  loaded with `model.load_state_dict(torch.load(out))`. This is more flexible
58
- and can be used to load the model with different lightly_train versions but
58
+ and can be used to load the model with different LightlyTrain versions but
59
59
  requires the model to already be instantiated.
60
60
  overwrite:
61
61
  Overwrite the output file if it already exists.
@@ -21,7 +21,9 @@ from pytorch_lightning.trainer.connectors.accelerator_connector import _PRECISIO
21
21
  from torch.nn import Module
22
22
  from torch.utils.data import Dataset
23
23
 
24
+ import lightly_train._loggers.logger_helpers
24
25
  from lightly_train import _logging
26
+ from lightly_train._callbacks import callback_helpers
25
27
  from lightly_train._commands import _warnings, common_helpers, train_helpers
26
28
  from lightly_train._configs import omegaconf_utils, validate
27
29
  from lightly_train._configs.config import Config
@@ -49,6 +51,8 @@ def train(
49
51
  strategy: str | Strategy = "auto",
50
52
  precision: _PRECISION_INPUT = "32-true", # Default precision in PyTorch Lightning
51
53
  seed: int = 0,
54
+ loggers: dict[str, dict[str, Any] | None] | None = None,
55
+ callbacks: dict[str, dict[str, Any] | None] | None = None,
52
56
  optim_args: dict[str, Any] | None = None,
53
57
  transform_args: dict[str, Any] | None = None,
54
58
  loader_args: dict[str, Any] | None = None,
@@ -111,6 +115,19 @@ def train(
111
115
  https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision
112
116
  seed:
113
117
  Random seed for reproducibility.
118
+ loggers:
119
+ Loggers for training. Either None or a dictionary of logger names to either
120
+ None or a dictionary of logger arguments. None uses the default loggers.
121
+ To disable a logger, set it to None: `loggers={"tensorboard": None}`.
122
+ To configure a logger, pass the respective arguments:
123
+ `loggers={"wandb": {"project": "my_project"}}`.
124
+ callbacks:
125
+ Callbacks for training. Either None or a dictionary of callback names to
126
+ either None or a dictionary of callback arguments. None uses the default
127
+ callbacks. To disable a callback, set it to None:
128
+ `callbacks={"model_checkpoint": None}`. To configure a callback, pass the
129
+ respective arguments:
130
+ `callbacks={"model_checkpoint": {"every_n_epochs": 5}}`.
114
131
  optim_args:
115
132
  Arguments for AdamW optimizer. Available arguments are:
116
133
  - lr: float
@@ -160,10 +177,21 @@ def train(
160
177
  log_every_n_steps = train_helpers.get_lightning_logging_interval(
161
178
  dataset_size=scaling_info.dataset_size, batch_size=batch_size
162
179
  )
163
- trainer_instance = train_helpers.get_trainer(
180
+ logger_args = lightly_train._loggers.logger_helpers.validate_logger_args(
181
+ loggers=loggers
182
+ )
183
+ logger_instances = lightly_train._loggers.logger_helpers.get_loggers(
184
+ logger_args=logger_args, out=out_dir
185
+ )
186
+ callback_args = callback_helpers.validate_callback_args(callback_dict=callbacks)
187
+ callback_instances = callback_helpers.get_callbacks(
188
+ callback_args=callback_args,
164
189
  out=out_dir,
165
190
  model=model_instance,
166
191
  embedding_model=embedding_model,
192
+ )
193
+ trainer_instance = train_helpers.get_trainer(
194
+ out=out_dir,
167
195
  epochs=epochs,
168
196
  accelerator=accelerator,
169
197
  strategy=strategy,
@@ -171,6 +199,8 @@ def train(
171
199
  num_nodes=num_nodes,
172
200
  precision=precision,
173
201
  log_every_n_steps=log_every_n_steps,
202
+ loggers=logger_instances,
203
+ callbacks=callback_instances,
174
204
  trainer_args=trainer_args,
175
205
  )
176
206
  dataloader = train_helpers.get_dataloader(
@@ -228,6 +258,8 @@ class TrainConfig(Config):
228
258
  strategy: str = "auto"
229
259
  precision: str = "32-true"
230
260
  seed: int = 0
261
+ loggers: Optional[Dict[str, Optional[Dict[str, Any]]]] = None
262
+ callbacks: Optional[Dict[str, Optional[Dict[str, Any]]]] = None
231
263
  optim_args: Optional[Dict[str, Any]] = None
232
264
  transform_args: Optional[Dict[str, Any]] = None
233
265
  loader_args: Optional[Dict[str, Any]] = None
@@ -7,32 +7,24 @@
7
7
  #
8
8
  from __future__ import annotations
9
9
 
10
+ import logging
10
11
  from pathlib import Path
11
12
  from typing import Any, Callable, Sized, Type
12
13
 
13
14
  from lightly.data import LightlyDataset
14
- from pytorch_lightning import Trainer
15
+ from pytorch_lightning import Callback, Trainer
15
16
  from pytorch_lightning.accelerators.accelerator import Accelerator
16
17
  from pytorch_lightning.accelerators.cpu import CPUAccelerator
17
18
  from pytorch_lightning.accelerators.cuda import CUDAAccelerator
18
- from pytorch_lightning.callbacks import (
19
- DeviceStatsMonitor,
20
- EarlyStopping,
21
- LearningRateMonitor,
22
- )
23
19
  from pytorch_lightning.loggers import Logger
24
20
  from pytorch_lightning.strategies.strategy import Strategy
25
21
  from pytorch_lightning.trainer.connectors.accelerator_connector import _PRECISION_INPUT
26
22
  from torch.nn import Module
27
23
  from torch.utils.data import DataLoader, Dataset
28
24
 
29
- from lightly_train._callbacks.checkpoint import ModelCheckpoint
30
- from lightly_train._checkpoint import CheckpointLightlyTrainModels
31
25
  from lightly_train._commands import common_helpers
32
26
  from lightly_train._configs import validate
33
27
  from lightly_train._constants import DATALOADER_TIMEOUT
34
- from lightly_train._loggers.jsonl import JSONLLogger
35
- from lightly_train._loggers.tensorboard import TensorBoardLogger
36
28
  from lightly_train._methods import method_helpers
37
29
  from lightly_train._methods.method import Method
38
30
  from lightly_train._models import package_helpers
@@ -42,18 +34,6 @@ from lightly_train._optim.optimizer_type import OptimizerType
42
34
  from lightly_train._scaling import IMAGENET_SIZE, ScalingInfo
43
35
  from lightly_train.types import PathLike, Transform
44
36
 
45
- try:
46
- import timm
47
- except ImportError:
48
- timm = None
49
-
50
- try:
51
- import tensorboard
52
- except ImportError:
53
- tensorboard = None
54
-
55
- import logging
56
-
57
37
  logger = logging.getLogger(__name__)
58
38
 
59
39
 
@@ -166,8 +146,6 @@ def get_embedding_model(model: Module, embed_dim: int | None = None) -> Embeddin
166
146
 
167
147
  def get_trainer(
168
148
  out: Path,
169
- model: Module,
170
- embedding_model: EmbeddingModel,
171
149
  epochs: int,
172
150
  accelerator: str | Accelerator,
173
151
  strategy: str | Strategy,
@@ -175,17 +153,11 @@ def get_trainer(
175
153
  num_nodes: int,
176
154
  log_every_n_steps: int,
177
155
  precision: _PRECISION_INPUT | None,
156
+ loggers: list[Logger],
157
+ callbacks: list[Callback],
178
158
  trainer_args: dict[str, Any] | None,
179
159
  ) -> Trainer:
180
160
  logger.debug("Getting trainer.")
181
- # Set version and name to empty string to save logs directly in the root
182
- # directory.
183
- loggers: list[Logger] = [
184
- JSONLLogger(save_dir=out, name="", version=""),
185
- ]
186
- if tensorboard is not None:
187
- loggers.append(TensorBoardLogger(save_dir=out, name="", version=""))
188
- logger.debug(f"Using loggers {[log.__class__.__name__ for log in loggers]}.")
189
161
 
190
162
  accelerator = common_helpers.get_accelerator(accelerator=accelerator)
191
163
  strategy = get_strategy(accelerator=accelerator, strategy=strategy, devices=devices)
@@ -200,20 +172,7 @@ def get_trainer(
200
172
  num_nodes=num_nodes,
201
173
  precision=precision,
202
174
  log_every_n_steps=log_every_n_steps,
203
- callbacks=[
204
- LearningRateMonitor(),
205
- DeviceStatsMonitor(),
206
- ModelCheckpoint(
207
- save_last=True,
208
- models=CheckpointLightlyTrainModels(
209
- model=model, embedding_model=embedding_model
210
- ),
211
- dirpath=out / "checkpoints",
212
- enable_version_counter=False,
213
- ),
214
- # Stop if training loss diverges.
215
- EarlyStopping(monitor="train_loss", patience=int(1e12), check_finite=True),
216
- ],
175
+ callbacks=callbacks,
217
176
  logger=loggers,
218
177
  sync_batchnorm=sync_batchnorm,
219
178
  )
@@ -5,9 +5,12 @@
5
5
  # This source code is licensed under the license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
7
  #
8
+ from __future__ import annotations
9
+
8
10
  import json
9
11
  import logging
10
12
  import os
13
+ from dataclasses import dataclass
11
14
  from typing import Dict, List, Optional, Union
12
15
 
13
16
  from lightning_fabric.loggers.logger import rank_zero_experiment
@@ -15,6 +18,7 @@ from lightning_fabric.utilities.rank_zero import rank_zero_warn
15
18
  from pytorch_lightning.loggers import CSVLogger
16
19
  from pytorch_lightning.loggers.csv_logs import ExperimentWriter as CSVExperimentWriter
17
20
 
21
+ from lightly_train._configs.config import Config
18
22
  from lightly_train.types import PathLike
19
23
 
20
24
  log = logging.getLogger(__name__)
@@ -55,6 +59,11 @@ class ExperimentWriter(CSVExperimentWriter):
55
59
  )
56
60
 
57
61
 
62
@dataclass
class JSONLLoggerArgs(Config):
    """User-configurable arguments for the JSONL training logger.

    The field name and default mirror the ``flush_logs_every_n_steps``
    keyword of ``JSONLLogger.__init__``. These args are presumably forwarded
    to that constructor by the logger helpers (TODO confirm in
    ``logger_helpers.get_loggers``).
    """

    # How often, in training steps, logs are flushed to disk. This reading
    # comes from the parameter name and the matching Lightning CSVLogger
    # argument that JSONLLogger inherits; verify.
    flush_logs_every_n_steps: int = 100
65
+
66
+
58
67
  class JSONLLogger(CSVLogger):
59
68
  """Log to local file system in JSON Lines format.
60
69
 
@@ -69,7 +78,7 @@ class JSONLLogger(CSVLogger):
69
78
 
70
79
  Args:
71
80
  save_dir: Save directory
72
- name: Experiment name. Defaults to ``'lightning_logs'``.
81
+ name: Experiment name.
73
82
  version:
74
83
  Experiment version. If version is not specified the logger inspects the save
75
84
  directory for existing versions, then automatically assigns the next
@@ -83,8 +92,8 @@ class JSONLLogger(CSVLogger):
83
92
  def __init__(
84
93
  self,
85
94
  save_dir: PathLike,
86
- name: str = "lightning_logs",
87
- version: Optional[Union[int, str]] = None,
95
+ name: str = "",
96
+ version: Optional[Union[int, str]] = "",
88
97
  prefix: str = "",
89
98
  flush_logs_every_n_steps: int = 100,
90
99
  ):
@@ -0,0 +1,23 @@
1
+ #
2
+ # Copyright (c) Lightly AG and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+ from dataclasses import dataclass, field
9
+ from typing import Optional
10
+
11
+ from lightly_train._configs.config import Config
12
+ from lightly_train._loggers.jsonl import JSONLLoggerArgs
13
+ from lightly_train._loggers.tensorboard import TensorBoardLoggerArgs
14
+ from lightly_train._loggers.wandb import WandbLoggerArgs
15
+
16
+
17
+ @dataclass
18
+ class LoggerArgs(Config):
19
+ jsonl: Optional[JSONLLoggerArgs] = field(default_factory=JSONLLoggerArgs)
20
+ tensorboard: Optional[TensorBoardLoggerArgs] = field(
21
+ default_factory=TensorBoardLoggerArgs
22
+ )
23
+ wandb: Optional[WandbLoggerArgs] = None