lightly-train 0.2.2__tar.gz → 0.2.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lightly_train-0.2.3/.github/pull_request_template.md +7 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/.gitignore +1 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/PKG-INFO +1 -1
- {lightly_train-0.2.2 → lightly_train-0.2.3}/docker/Dockerfile-amd64-cuda +7 -8
- {lightly_train-0.2.2 → lightly_train-0.2.3}/docker/Makefile +17 -6
- {lightly_train-0.2.2 → lightly_train-0.2.3}/docker/README.md +14 -1
- {lightly_train-0.2.2 → lightly_train-0.2.3}/pyproject.toml +4 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/__init__.py +9 -1
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_callbacks/checkpoint.py +4 -3
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_checkpoint.py +28 -2
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_cli.py +42 -4
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_commands/_warnings.py +23 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_commands/common_helpers.py +30 -15
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_commands/embed.py +24 -4
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_commands/export.py +17 -3
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_commands/extract_video_frames.py +23 -4
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_commands/train.py +24 -4
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_commands/train_helpers.py +56 -4
- lightly_train-0.2.3/src/lightly_train/_logging.py +171 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/timm/timm.py +4 -2
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_transforms.py +7 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train.egg-info/PKG-INFO +1 -1
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train.egg-info/SOURCES.txt +3 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_commands/test_common_helpers.py +80 -20
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_commands/test_train_helpers.py +41 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/test__cli.py +13 -10
- lightly_train-0.2.3/tests/test__logging.py +88 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/.github/workflows/build_docker_image.yml +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/.github/workflows/release_dockerhub.yml +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/.github/workflows/release_pypi.yml +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/.github/workflows/test.yml +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/.github/workflows/test_code_format.yml +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/.github/workflows/test_docker.yml +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/.github/workflows/test_minimal_deps.yml +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/.github/workflows/weekly_dependency_tests.yml +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/CONTRIBUTE.md +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/DOCS.md +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/LICENSE +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/Makefile +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/README.md +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/dev_tools/licenseheader.tmpl +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/setup.cfg +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_callbacks/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_commands/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_configs/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_configs/config.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_configs/omegaconf_utils.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_configs/validate.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_constants.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_embedding/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_embedding/embedding_format.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_embedding/embedding_predictor.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_embedding/embedding_transform.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_embedding/writers/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_embedding/writers/csv_writer.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_embedding/writers/embedding_writer.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_embedding/writers/torch_writer.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_embedding/writers/writer_helpers.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_loggers/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_loggers/jsonl.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_loggers/tensorboard.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_methods/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_methods/densecl.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_methods/densecldino.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_methods/dino.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_methods/method.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_methods/method_args.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_methods/method_helpers.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_methods/simclr.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/custom/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/custom/custom.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/custom/custom_package.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/embedding/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/embedding/base.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/embedding_model.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/feature_extractor.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/package.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/package_helpers.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/super_gradients/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/super_gradients/customizable_detector.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/super_gradients/super_gradients.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/super_gradients/super_gradients_package.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/timm/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/timm/timm_package.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/torchvision/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/torchvision/convnext.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/torchvision/resnet.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/torchvision/torchvision.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_models/torchvision/torchvision_package.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_optim/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_optim/adamw_args.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_optim/optimizer.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_optim/optimizer_args.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_optim/optimizer_type.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_optim/trainable_modules.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/_scaling.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/errors.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train/types.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train.egg-info/dependency_links.txt +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train.egg-info/entry_points.txt +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train.egg-info/requires.txt +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/src/lightly_train.egg-info/top_level.txt +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_commands/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_commands/test_embed.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_commands/test_export.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_commands/test_extract_video_frames.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_commands/test_train.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_configs/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_configs/test_validate.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_embedding/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_embedding/writers/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_embedding/writers/test_csv_writer.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_embedding/writers/test_torch_writer.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_embedding/writers/test_writer_helpers.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_methods/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_methods/test_method_helpers.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_models/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_models/custom/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_models/custom/test_custom.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_models/custom/test_custom_package.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_models/super_gradients/test_super_gradients_package.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_models/test_package_helpers.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_models/timm/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_models/timm/test_timm.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_models/timm/test_timm_package.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_models/torchvision/test_torchvision_package.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_optim/__init__.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/_optim/test_optimizer.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/helpers.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/test__checkpoint.py +0 -0
- {lightly_train-0.2.2 → lightly_train-0.2.3}/tests/test__scaling.py +0 -0
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
## What has changed and why?
|
|
2
|
+
|
|
3
|
+
(Delete this: Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.)
|
|
4
|
+
|
|
5
|
+
## How has it been tested?
|
|
6
|
+
|
|
7
|
+
(Delete this: Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration.)
|
|
@@ -4,10 +4,8 @@
|
|
|
4
4
|
# Start FROM PyTorch image https://hub.docker.com/r/pytorch/pytorch
|
|
5
5
|
FROM pytorch/pytorch:2.3.0-cuda11.8-cudnn8-runtime AS runtime
|
|
6
6
|
|
|
7
|
-
#
|
|
8
|
-
ENV
|
|
9
|
-
# Install packages into the system Python and skip creating a virtual environment.
|
|
10
|
-
UV_SYSTEM_PYTHON="true" \
|
|
7
|
+
# Install packages into the system Python and skip creating a virtual environment.
|
|
8
|
+
ENV UV_SYSTEM_PYTHON="true" \
|
|
11
9
|
# Do not cache dependencies as they would also be saved in the docker image.
|
|
12
10
|
UV_NO_CACHE="true"
|
|
13
11
|
|
|
@@ -19,8 +17,8 @@ WORKDIR /home/lightly_train
|
|
|
19
17
|
|
|
20
18
|
# Install uv
|
|
21
19
|
COPY Makefile /home/lightly_train
|
|
22
|
-
|
|
23
20
|
RUN make install-uv
|
|
21
|
+
|
|
24
22
|
# Add uv to PATH
|
|
25
23
|
ENV PATH="/root/.cargo/bin:$PATH"
|
|
26
24
|
|
|
@@ -31,8 +29,9 @@ RUN make install-docker-dependencies
|
|
|
31
29
|
# Copy the package itself
|
|
32
30
|
COPY src /home/lightly_train/src
|
|
33
31
|
|
|
32
|
+
# Set and create the directory to save pretrained torch models into
|
|
33
|
+
ENV TORCH_HOME="/home/lightly_train/.cache/torch"
|
|
34
|
+
RUN mkdir -p ${TORCH_HOME} && chmod -R a+w $TORCH_HOME
|
|
35
|
+
|
|
34
36
|
# Install the package.
|
|
35
37
|
RUN make install-docker
|
|
36
|
-
|
|
37
|
-
# Default command that runs when the container starts.
|
|
38
|
-
ENTRYPOINT ["lightly-train"]
|
|
@@ -62,15 +62,26 @@ test:
|
|
|
62
62
|
@echo "Generate images"
|
|
63
63
|
mkdir -p $(LIGHTLY_TRAIN_DATA)
|
|
64
64
|
python -c 'from PIL import Image; [Image.new("RGB", (250, 300)).save(f"$(LIGHTLY_TRAIN_DATA)/{i}.png") for i in range(5)]'
|
|
65
|
+
@echo "Create output directory"
|
|
66
|
+
mkdir -p $(LIGHTLY_TRAIN_OUT)
|
|
67
|
+
chmod -R +rw $(LIGHTLY_TRAIN_OUT)
|
|
68
|
+
docker run --rm --shm-size=1g --user $(shell id -u):$(shell id -g) \
|
|
69
|
+
-v $(LIGHTLY_TRAIN_OUT):/out \
|
|
70
|
+
-v $(LIGHTLY_TRAIN_DATA):/data \
|
|
71
|
+
-v ./Makefile:/home/lightly_train/docker/Makefile \
|
|
72
|
+
lightly/$(IMAGE):$(TAG) make test-cli-from-within-docker -C docker
|
|
73
|
+
|
|
74
|
+
# This target is run from within the docker container
|
|
75
|
+
test-cli-from-within-docker:
|
|
65
76
|
@echo "Test train"
|
|
66
|
-
|
|
67
|
-
test -f
|
|
77
|
+
lightly-train train data=/data out=/out model="torchvision/convnext_small" epochs=2 batch_size=2 model_args.weights="IMAGENET1K_V1" devices=2
|
|
78
|
+
test -f /out/checkpoints/last.ckpt
|
|
68
79
|
@echo "Test embed"
|
|
69
|
-
|
|
70
|
-
test `wc -l <
|
|
80
|
+
lightly-train embed data=/data out="/out/embeddings.csv" checkpoint="/out/checkpoints/last.ckpt" batch_size=2 format="csv"
|
|
81
|
+
test `wc -l < /out/embeddings.csv` -eq 6
|
|
71
82
|
@echo "Test export"
|
|
72
|
-
|
|
73
|
-
test -f
|
|
83
|
+
lightly-train export out="/out/model.pth" checkpoint="/out/checkpoints/last.ckpt" part="model" format="torch_state_dict"
|
|
84
|
+
test -f /out/model.pth
|
|
74
85
|
|
|
75
86
|
|
|
76
87
|
|
|
@@ -21,7 +21,20 @@ TODO
|
|
|
21
21
|
|
|
22
22
|
## Usage
|
|
23
23
|
|
|
24
|
-
|
|
24
|
+
First, start the docker container in interactive mode by using the -it flag. Furthermore,
|
|
25
|
+
you must mount the directories you want to use.
|
|
26
|
+
|
|
27
|
+
```
|
|
28
|
+
docker run -it --gpus=all --user $(id -u):$(id -g) -v /my_data_dir:/data -v /my_output_dir:/out lightly/train:latest
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
Then all the usual CLI commands are fully available. E.g. run
|
|
32
|
+
|
|
33
|
+
```
|
|
34
|
+
lightly-train train data="/data" out="/out" model="torchvision/convnext_small" method=dino
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
|
|
25
38
|
|
|
26
39
|
|
|
27
40
|
## Development
|
|
@@ -89,3 +89,7 @@ known-first-party = ["lightly_train"]
|
|
|
89
89
|
[tool.ruff.lint.pydocstyle]
|
|
90
90
|
# Use Google-style docstrings.
|
|
91
91
|
convention = "google"
|
|
92
|
+
|
|
93
|
+
[tool.ruff.lint.per-file-ignores]
|
|
94
|
+
# Ignore `E402` (import violations) in root `__init__.py`.
|
|
95
|
+
"src/lightly_train/__init__.py" = ["E402"]
|
|
@@ -5,6 +5,14 @@
|
|
|
5
5
|
# This source code is licensed under the license found in the
|
|
6
6
|
# LICENSE file in the root directory of this source tree.
|
|
7
7
|
#
|
|
8
|
+
|
|
9
|
+
# Disable beta transforms warning by torchvision.
|
|
10
|
+
# See https://stackoverflow.com/questions/77279407
|
|
11
|
+
# TODO(Philipp, 09/24): Remove this once the warning is removed.
|
|
12
|
+
import torchvision
|
|
13
|
+
|
|
14
|
+
torchvision.disable_beta_transforms_warning()
|
|
15
|
+
|
|
8
16
|
from lightly_train._commands.embed import embed
|
|
9
17
|
from lightly_train._commands.export import ModelFormat, ModelPart, export
|
|
10
18
|
from lightly_train._commands.train import train
|
|
@@ -23,4 +31,4 @@ __all__ = [
|
|
|
23
31
|
"train",
|
|
24
32
|
]
|
|
25
33
|
|
|
26
|
-
__version__ = "0.2.
|
|
34
|
+
__version__ = "0.2.3"
|
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|
#
|
|
8
8
|
from __future__ import annotations
|
|
9
9
|
|
|
10
|
-
import
|
|
10
|
+
import logging
|
|
11
11
|
from datetime import timedelta
|
|
12
12
|
from typing import Any
|
|
13
13
|
|
|
@@ -21,6 +21,8 @@ from lightly_train._checkpoint import (
|
|
|
21
21
|
)
|
|
22
22
|
from lightly_train.types import PathLike
|
|
23
23
|
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
24
26
|
|
|
25
27
|
class ModelCheckpoint(_ModelCheckpoint):
|
|
26
28
|
def __init__(
|
|
@@ -78,7 +80,6 @@ class ModelCheckpoint(_ModelCheckpoint):
|
|
|
78
80
|
checkpoint
|
|
79
81
|
).models
|
|
80
82
|
except KeyError as ex:
|
|
81
|
-
|
|
83
|
+
logger.warning(
|
|
82
84
|
f"Could not restore lightly_train models from checkpoint: {ex}"
|
|
83
85
|
)
|
|
84
|
-
pass
|
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
from __future__ import annotations
|
|
9
9
|
|
|
10
10
|
import dataclasses
|
|
11
|
+
import logging
|
|
11
12
|
from dataclasses import dataclass
|
|
12
13
|
from datetime import datetime, timezone
|
|
13
14
|
from pathlib import Path
|
|
@@ -21,6 +22,8 @@ from torch.serialization import MAP_LOCATION
|
|
|
21
22
|
import lightly_train
|
|
22
23
|
from lightly_train._models.embedding_model import EmbeddingModel
|
|
23
24
|
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
24
27
|
CHECKPOINT_LIGHTLY_TRAIN_KEY = "lightly_train"
|
|
25
28
|
|
|
26
29
|
|
|
@@ -104,11 +107,34 @@ class Checkpoint:
|
|
|
104
107
|
|
|
105
108
|
@staticmethod
|
|
106
109
|
def from_path(
|
|
107
|
-
checkpoint: Path,
|
|
110
|
+
checkpoint: Path,
|
|
111
|
+
map_location: MAP_LOCATION | None = "cpu",
|
|
112
|
+
weights_only: bool = False,
|
|
108
113
|
) -> Checkpoint:
|
|
109
|
-
|
|
114
|
+
"""Load a checkpoint from a file path.
|
|
115
|
+
|
|
116
|
+
Args:
|
|
117
|
+
checkpoint:
|
|
118
|
+
Path to the checkpoint file.
|
|
119
|
+
map_location:
|
|
120
|
+
If map_location is a string, it must be a key in torch.device, such as
|
|
121
|
+
'cpu' or 'cuda:0'. If map_location is a torch.device, it will be used to
|
|
122
|
+
determine where the checkpoint should be loaded to. Default: 'cpu'.
|
|
123
|
+
weights_only:
|
|
124
|
+
If False (default), the whole checkpoint is loaded. If True, only the weights
|
|
125
|
+
of the model are loaded. This requires the user to add safe globals with:
|
|
126
|
+
https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
|
|
127
|
+
TODO(Philipp, 09/24): Expose weights_only argument to the user.
|
|
128
|
+
"""
|
|
129
|
+
logger.debug(
|
|
130
|
+
f"Loading checkpoint from '{checkpoint}' with map_location '{map_location}' and weights_only {weights_only}"
|
|
131
|
+
)
|
|
132
|
+
checkpoint_dict = torch.load(
|
|
133
|
+
checkpoint, map_location=map_location, weights_only=weights_only
|
|
134
|
+
)
|
|
110
135
|
return Checkpoint.from_dict(checkpoint=checkpoint_dict)
|
|
111
136
|
|
|
112
137
|
def save(self, path: Path) -> None:
|
|
138
|
+
logger.debug(f"Saving checkpoint to '{path}'")
|
|
113
139
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
114
140
|
torch.save(self.to_dict(), path)
|
|
@@ -6,17 +6,23 @@
|
|
|
6
6
|
# LICENSE file in the root directory of this source tree.
|
|
7
7
|
#
|
|
8
8
|
import inspect
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
9
11
|
import sys
|
|
10
12
|
from typing import Callable
|
|
11
13
|
|
|
12
14
|
from omegaconf import DictConfig, OmegaConf
|
|
13
15
|
|
|
14
16
|
import lightly_train
|
|
17
|
+
from lightly_train import _logging
|
|
15
18
|
from lightly_train._commands import embed, export, extract_video_frames, train
|
|
16
19
|
from lightly_train._commands.train import TrainConfig
|
|
20
|
+
from lightly_train._logging import LIGHTLY_TRAIN_LOG_LEVEL_ENV_VAR
|
|
17
21
|
from lightly_train._models import package_helpers
|
|
18
22
|
from lightly_train.errors import ConfigError
|
|
19
23
|
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
20
26
|
_HELP_COMMANDS = {"help", "--help", "-h"}
|
|
21
27
|
_HELP_MSG = """
|
|
22
28
|
Commands:
|
|
@@ -29,6 +35,9 @@ _HELP_MSG = """
|
|
|
29
35
|
lightly-train help Show help message.
|
|
30
36
|
|
|
31
37
|
Run `lightly_train <command> help` for more information on a specific command.
|
|
38
|
+
|
|
39
|
+
Optional arguments:
|
|
40
|
+
-v, --verbose Run the command in verbose mode for detailed output.
|
|
32
41
|
"""
|
|
33
42
|
|
|
34
43
|
_train_cfg = TrainConfig()
|
|
@@ -115,6 +124,9 @@ _TRAIN_HELP_MSG = f"""
|
|
|
115
124
|
Additional arguments for the model. The available arguments depend on the
|
|
116
125
|
the `model` parameter.
|
|
117
126
|
|
|
127
|
+
Optional arguments:
|
|
128
|
+
-v, --verbose Run the command in verbose mode for detailed output.
|
|
129
|
+
|
|
118
130
|
Examples:
|
|
119
131
|
# Train a ResNet-18 model with SimCLR on ImageNet
|
|
120
132
|
lightly-train train out=out data=imagenet/train model=torchvision/resnet18 method=simclr
|
|
@@ -153,6 +165,9 @@ _EXPORT_HELP_MSG = """
|
|
|
153
165
|
and can be used to load the model with different lightly_train versions but
|
|
154
166
|
requires the model to already be instantiated.
|
|
155
167
|
|
|
168
|
+
Optional arguments:
|
|
169
|
+
-v, --verbose Run the command in verbose mode for detailed output.
|
|
170
|
+
|
|
156
171
|
Examples:
|
|
157
172
|
# Export the state dict of the model
|
|
158
173
|
lightly-train export checkpoint=out/checkpoints/last.ckpt out=out/model.pth \\
|
|
@@ -201,6 +216,9 @@ _EMBED_HELP_MSG = """
|
|
|
201
216
|
overwrite (bool):
|
|
202
217
|
Overwrite the output file if it already exists.
|
|
203
218
|
|
|
219
|
+
Optional arguments:
|
|
220
|
+
-v, --verbose Run the command in verbose mode for detailed output.
|
|
221
|
+
|
|
204
222
|
Examples:
|
|
205
223
|
# Embed images from a model checkpoint
|
|
206
224
|
lightly-train embed out=embeddings.csv data=images checkpoint=out/checkpoints/last.ckpt \\
|
|
@@ -245,6 +263,9 @@ _EXTRACT_VIDEO_FRAMES_HELP_MSG = f"""
|
|
|
245
263
|
Number of parallel calls to ffmpeg when processing multiple videos.
|
|
246
264
|
If None, the number of workers is set to the number of available CPU cores.
|
|
247
265
|
|
|
266
|
+
Optional arguments:
|
|
267
|
+
-v, --verbose Run the command in verbose mode for detailed output.
|
|
268
|
+
|
|
248
269
|
Examples:
|
|
249
270
|
# Extract frames from videos
|
|
250
271
|
lightly-train extract_video_frames data=videos out=frames
|
|
@@ -257,12 +278,25 @@ _EXTRACT_VIDEO_FRAMES_HELP_MSG = f"""
|
|
|
257
278
|
"""
|
|
258
279
|
|
|
259
280
|
|
|
281
|
+
_VERBOSE_FLAGS = ["-v", "--verbose"]
|
|
282
|
+
|
|
283
|
+
|
|
260
284
|
def cli(config: DictConfig) -> None:
|
|
285
|
+
keys = list(config.keys())
|
|
286
|
+
|
|
287
|
+
# Check if the user wants to run the command in verbose mode.
|
|
288
|
+
# Any of the following will enable verbose mode: -v, --verbose
|
|
289
|
+
if any(flag in keys for flag in _VERBOSE_FLAGS):
|
|
290
|
+
os.environ[LIGHTLY_TRAIN_LOG_LEVEL_ENV_VAR] = str(logging.DEBUG)
|
|
291
|
+
config = OmegaConf.create(
|
|
292
|
+
{k: v for k, v in config.items() if k not in _VERBOSE_FLAGS}
|
|
293
|
+
)
|
|
294
|
+
_logging.set_up_console_logging()
|
|
295
|
+
|
|
261
296
|
if config.is_empty():
|
|
262
297
|
_show_help()
|
|
263
298
|
return
|
|
264
299
|
|
|
265
|
-
keys = list(config.keys())
|
|
266
300
|
# First argument after lightly_train is the command. For example `lightly-train train ...`
|
|
267
301
|
command = str(keys[0]).lower()
|
|
268
302
|
help_if_config_empty = True
|
|
@@ -336,17 +370,21 @@ def _run_command_fn(
|
|
|
336
370
|
try:
|
|
337
371
|
command_fn(config)
|
|
338
372
|
except ConfigError as ex:
|
|
373
|
+
logger.error(ex)
|
|
374
|
+
raise ex from None # Shorten stacktrace
|
|
375
|
+
except Exception as ex:
|
|
376
|
+
logger.error(ex)
|
|
339
377
|
raise ex from None # Shorten stacktrace
|
|
340
378
|
|
|
341
379
|
|
|
342
380
|
def _list_models(config: DictConfig) -> None:
|
|
343
381
|
lines = [f" {model}" for model in package_helpers.list_model_names()]
|
|
344
|
-
|
|
382
|
+
logger.info("\n".join(lines))
|
|
345
383
|
|
|
346
384
|
|
|
347
385
|
def _list_methods(config: DictConfig) -> None:
|
|
348
386
|
lines = [f" {method}" for method in lightly_train.list_methods()]
|
|
349
|
-
|
|
387
|
+
logger.info("\n".join(lines))
|
|
350
388
|
|
|
351
389
|
|
|
352
390
|
def _is_help_command_in_config(config: DictConfig) -> bool:
|
|
@@ -370,7 +408,7 @@ def _show_invalid_command_help(command: str) -> None:
|
|
|
370
408
|
|
|
371
409
|
|
|
372
410
|
def _show_msg(msg: str) -> None:
|
|
373
|
-
|
|
411
|
+
logger.info(_format_msg(msg))
|
|
374
412
|
|
|
375
413
|
|
|
376
414
|
def _format_msg(msg: str) -> str:
|
|
@@ -37,6 +37,29 @@ def filter_train_warnings() -> None:
|
|
|
37
37
|
"in this directory can be modified when the new ones are saved!"
|
|
38
38
|
),
|
|
39
39
|
)
|
|
40
|
+
# Ignore mixed precision CUDA warnings as the information that CUDA is not available
|
|
41
|
+
# can be found elsewhere. The same warnings don't pop up for full precision training.
|
|
42
|
+
warnings.filterwarnings(
|
|
43
|
+
"ignore",
|
|
44
|
+
message="User provided device_type of 'cuda', but CUDA is not available.",
|
|
45
|
+
module="torch.amp",
|
|
46
|
+
category=UserWarning,
|
|
47
|
+
)
|
|
48
|
+
warnings.filterwarnings(
|
|
49
|
+
"ignore",
|
|
50
|
+
message="torch.cuda.amp.GradScaler is enabled, but CUDA is not available.",
|
|
51
|
+
module="torch.amp",
|
|
52
|
+
category=UserWarning,
|
|
53
|
+
)
|
|
54
|
+
# Ignore `lr_scheduler.step()` before `optimizer.step()` warning as it's a PyTorch Lightning issue.
|
|
55
|
+
# See https://github.com/Lightning-AI/pytorch-lightning/issues/5558
|
|
56
|
+
# TODO(Philipp, 09/24): Remove this once the issue is resolved.
|
|
57
|
+
warnings.filterwarnings(
|
|
58
|
+
"ignore",
|
|
59
|
+
message="Detected call of \`lr_scheduler.step\(\)\` before \`optimizer.step\(\)\`",
|
|
60
|
+
category=UserWarning,
|
|
61
|
+
module="torch.optim.lr_scheduler",
|
|
62
|
+
)
|
|
40
63
|
|
|
41
64
|
|
|
42
65
|
def filter_embed_warnings() -> None:
|
|
@@ -7,20 +7,25 @@
|
|
|
7
7
|
#
|
|
8
8
|
from __future__ import annotations
|
|
9
9
|
|
|
10
|
-
import
|
|
10
|
+
import logging
|
|
11
|
+
import pprint
|
|
11
12
|
from pathlib import Path
|
|
13
|
+
from typing import Any
|
|
12
14
|
|
|
13
|
-
from omegaconf import MISSING
|
|
14
15
|
from pytorch_lightning.accelerators.accelerator import Accelerator
|
|
15
16
|
from pytorch_lightning.accelerators.cpu import CPUAccelerator
|
|
16
17
|
from pytorch_lightning.accelerators.cuda import CUDAAccelerator
|
|
17
18
|
from pytorch_lightning.accelerators.mps import MPSAccelerator
|
|
19
|
+
from torch.nn import Module
|
|
18
20
|
|
|
19
21
|
from lightly_train.types import PathLike
|
|
20
22
|
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
21
25
|
|
|
22
26
|
def get_checkpoint_path(checkpoint: PathLike) -> Path:
|
|
23
27
|
checkpoint_path = Path(checkpoint).resolve()
|
|
28
|
+
logger.debug(f"Making sure checkpoint '{checkpoint_path}' exists.")
|
|
24
29
|
if not checkpoint_path.exists():
|
|
25
30
|
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' does not exist!")
|
|
26
31
|
if not checkpoint_path.is_file():
|
|
@@ -30,6 +35,7 @@ def get_checkpoint_path(checkpoint: PathLike) -> Path:
|
|
|
30
35
|
|
|
31
36
|
def get_out_path(out: PathLike, overwrite: bool) -> Path:
|
|
32
37
|
out_path = Path(out).resolve()
|
|
38
|
+
logger.debug(f"Checking if output path '{out_path}' exists.")
|
|
33
39
|
if out_path.exists():
|
|
34
40
|
if not overwrite:
|
|
35
41
|
raise ValueError(
|
|
@@ -44,37 +50,37 @@ def get_out_path(out: PathLike, overwrite: bool) -> Path:
|
|
|
44
50
|
def get_accelerator(
|
|
45
51
|
accelerator: str | Accelerator,
|
|
46
52
|
) -> str | Accelerator:
|
|
53
|
+
logger.debug(f"Getting accelerator for '{accelerator}'.")
|
|
47
54
|
if accelerator != "auto":
|
|
48
55
|
# User specified an accelerator, return it.
|
|
49
56
|
return accelerator
|
|
50
57
|
|
|
51
58
|
# Default to CUDA if available.
|
|
52
59
|
if CUDAAccelerator.is_available():
|
|
60
|
+
logger.debug("CUDA is available, defaulting to CUDA.")
|
|
53
61
|
return CUDAAccelerator()
|
|
54
62
|
elif MPSAccelerator.is_available():
|
|
63
|
+
logger.debug("MPS is available, defaulting to MPS.")
|
|
55
64
|
return MPSAccelerator()
|
|
56
65
|
else:
|
|
66
|
+
logger.debug("CUDA and MPS are not available, defaulting to CPU.")
|
|
57
67
|
return CPUAccelerator()
|
|
58
68
|
|
|
59
69
|
|
|
60
|
-
def get_default_out() -> str:
|
|
61
|
-
return (
|
|
62
|
-
"/out" if os.getenv("LIGHTLY_TRAIN_IS_DOCKER", "False") == "True" else MISSING
|
|
63
|
-
)
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
def get_default_data() -> str:
|
|
67
|
-
return (
|
|
68
|
-
"/data" if os.getenv("LIGHTLY_TRAIN_IS_DOCKER", "False") == "True" else MISSING
|
|
69
|
-
)
|
|
70
|
-
|
|
71
|
-
|
|
72
70
|
def get_out_dir(out: PathLike, resume: bool, overwrite: bool) -> Path:
|
|
73
71
|
out_dir = Path(out).resolve()
|
|
72
|
+
logger.debug(f"Checking if output directory '{out_dir}' exists.")
|
|
74
73
|
if out_dir.exists():
|
|
75
74
|
if not out_dir.is_dir():
|
|
76
75
|
raise ValueError(f"Output '{out_dir}' is not a directory!")
|
|
77
|
-
|
|
76
|
+
|
|
77
|
+
# Ignore the train.log file as it can already exist when using multiple devices.
|
|
78
|
+
# TODO(Guarin, 09/24): Fix this by checking that the directory is completely
|
|
79
|
+
# empty at the beginning. For this we have to take multiple devices and repeat
|
|
80
|
+
# calls to this function into account.
|
|
81
|
+
dir_not_empty = any(
|
|
82
|
+
filepath for filepath in out_dir.iterdir() if filepath.name != "train.log"
|
|
83
|
+
)
|
|
78
84
|
if dir_not_empty and not (resume or overwrite):
|
|
79
85
|
raise ValueError(
|
|
80
86
|
f"Output '{out_dir}' is not empty! Set overwrite=True to overwrite the "
|
|
@@ -82,3 +88,12 @@ def get_out_dir(out: PathLike, resume: bool, overwrite: bool) -> Path:
|
|
|
82
88
|
)
|
|
83
89
|
out_dir.mkdir(parents=True, exist_ok=True)
|
|
84
90
|
return out_dir
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def pretty_format_args(
|
|
94
|
+
args: dict[str, Any], indent: int = 2, width: int = 200, compact: bool = True
|
|
95
|
+
) -> str:
|
|
96
|
+
if isinstance(args.get("model"), Module):
|
|
97
|
+
args["model"] = args["model"].__class__.__name__
|
|
98
|
+
|
|
99
|
+
return f"Args: {pprint.pformat(args, indent=indent, width=width, compact=compact)}"
|
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|
#
|
|
8
8
|
from __future__ import annotations
|
|
9
9
|
|
|
10
|
-
import
|
|
10
|
+
import logging
|
|
11
11
|
from dataclasses import dataclass
|
|
12
12
|
from pathlib import Path
|
|
13
13
|
|
|
@@ -21,6 +21,7 @@ from pytorch_lightning import Trainer
|
|
|
21
21
|
from pytorch_lightning.accelerators.accelerator import Accelerator
|
|
22
22
|
from torch.utils.data import DataLoader, Dataset
|
|
23
23
|
|
|
24
|
+
from lightly_train import _logging
|
|
24
25
|
from lightly_train._checkpoint import Checkpoint
|
|
25
26
|
from lightly_train._commands import _warnings, common_helpers
|
|
26
27
|
from lightly_train._commands.common_helpers import get_checkpoint_path, get_out_path
|
|
@@ -35,6 +36,8 @@ from lightly_train._embedding.writers.embedding_writer import EmbeddingWriter
|
|
|
35
36
|
from lightly_train._models.embedding_model import EmbeddingModel
|
|
36
37
|
from lightly_train.types import PathLike
|
|
37
38
|
|
|
39
|
+
logger = logging.getLogger(__name__)
|
|
40
|
+
|
|
38
41
|
|
|
39
42
|
def embed(
|
|
40
43
|
out: PathLike,
|
|
@@ -82,7 +85,12 @@ def embed(
|
|
|
82
85
|
overwrite:
|
|
83
86
|
Overwrite the output file if it already exists.
|
|
84
87
|
"""
|
|
88
|
+
# Set up logging.
|
|
85
89
|
_warnings.filter_embed_warnings()
|
|
90
|
+
_logging.set_up_console_logging()
|
|
91
|
+
logger.info(common_helpers.pretty_format_args(args=locals()))
|
|
92
|
+
|
|
93
|
+
logger.info(f"Embedding images in '{data}'.")
|
|
86
94
|
format = _get_format(format=format)
|
|
87
95
|
out_path = get_out_path(out=out, overwrite=overwrite)
|
|
88
96
|
checkpoint_path = get_checkpoint_path(checkpoint=checkpoint)
|
|
@@ -103,9 +111,11 @@ def embed(
|
|
|
103
111
|
dataloaders=dataloader,
|
|
104
112
|
return_predictions=False,
|
|
105
113
|
)
|
|
114
|
+
logger.info(f"Embeddings saved to '{out_path}'.")
|
|
106
115
|
|
|
107
116
|
|
|
108
117
|
def embed_from_config(config: DictConfig) -> None:
|
|
118
|
+
logger.debug(f"Embedding images with config: {config}")
|
|
109
119
|
config = _parse_config(config=config)
|
|
110
120
|
config = _validate_config(config=config)
|
|
111
121
|
config_dict = _config_to_dict(config=config)
|
|
@@ -114,8 +124,8 @@ def embed_from_config(config: DictConfig) -> None:
|
|
|
114
124
|
|
|
115
125
|
@dataclass
|
|
116
126
|
class EmbedConfig(Config):
|
|
117
|
-
out: str =
|
|
118
|
-
data: str =
|
|
127
|
+
out: str = MISSING
|
|
128
|
+
data: str = MISSING
|
|
119
129
|
checkpoint: str = MISSING
|
|
120
130
|
format: str = MISSING
|
|
121
131
|
# OmegaConf doesn't support Union[int, Tuple[int, int]] so we have to use the
|
|
@@ -128,6 +138,7 @@ class EmbedConfig(Config):
|
|
|
128
138
|
|
|
129
139
|
|
|
130
140
|
def _get_format(format: EmbeddingFormat | str) -> EmbeddingFormat:
|
|
141
|
+
logger.debug(f"Getting embedding format for '{format}'.")
|
|
131
142
|
try:
|
|
132
143
|
return EmbeddingFormat(format)
|
|
133
144
|
except ValueError:
|
|
@@ -140,8 +151,10 @@ def _get_format(format: EmbeddingFormat | str) -> EmbeddingFormat:
|
|
|
140
151
|
def _get_transform(
|
|
141
152
|
image_size: int | tuple[int, int],
|
|
142
153
|
) -> EmbeddingTransform:
|
|
154
|
+
logger.debug(f"Getting embedding transform for image size {image_size}.")
|
|
143
155
|
mean = tuple(IMAGENET_NORMALIZE["mean"])
|
|
144
156
|
std = tuple(IMAGENET_NORMALIZE["std"])
|
|
157
|
+
logger.debug(f"Using mean {mean} and std {std} for normalization.")
|
|
145
158
|
assert len(mean) == len(std) == 3
|
|
146
159
|
if isinstance(image_size, int):
|
|
147
160
|
image_size = (image_size, image_size)
|
|
@@ -159,8 +172,10 @@ def _get_dataset(
|
|
|
159
172
|
transform: EmbeddingTransform,
|
|
160
173
|
) -> Dataset:
|
|
161
174
|
if isinstance(data, Dataset):
|
|
175
|
+
logger.debug("Using provided dataset.")
|
|
162
176
|
return data
|
|
163
177
|
else:
|
|
178
|
+
logger.debug(f"Loading LightlyDataset from '{data}'.")
|
|
164
179
|
return LightlyDataset(input_dir=str(data), transform=transform)
|
|
165
180
|
|
|
166
181
|
|
|
@@ -169,12 +184,15 @@ def _get_dataloader(
|
|
|
169
184
|
batch_size: int,
|
|
170
185
|
num_workers: int,
|
|
171
186
|
) -> DataLoader:
|
|
187
|
+
logger.debug(
|
|
188
|
+
f"Getting dataloader with batch_size {batch_size} and num_workers {num_workers}."
|
|
189
|
+
)
|
|
172
190
|
if isinstance(dataset, Sized):
|
|
173
191
|
dataset_size = len(dataset)
|
|
174
192
|
if batch_size > dataset_size:
|
|
175
193
|
old_batch_size = batch_size
|
|
176
194
|
batch_size = dataset_size
|
|
177
|
-
|
|
195
|
+
logger.warning(
|
|
178
196
|
f"Detected dataset size {dataset_size} and batch size "
|
|
179
197
|
f"{old_batch_size}. Reducing batch size to {batch_size}."
|
|
180
198
|
)
|
|
@@ -190,11 +208,13 @@ def _get_dataloader(
|
|
|
190
208
|
|
|
191
209
|
|
|
192
210
|
def _get_embedding_model(checkpoint_path: Path) -> EmbeddingModel:
    """Load the checkpoint at ``checkpoint_path`` and return its embedding model.

    The model is read from ``checkpoint.lightly_train.models.embedding_model``
    of the ``Checkpoint`` created by ``Checkpoint.from_path``.
    """
    logger.debug(f"Loading embedding model from '{checkpoint_path}'.")
    loaded_checkpoint = Checkpoint.from_path(checkpoint=checkpoint_path)
    lightly_train_state = loaded_checkpoint.lightly_train
    embedding_model = lightly_train_state.models.embedding_model
    return embedding_model
|
|
195
214
|
|
|
196
215
|
|
|
197
216
|
def _get_trainer(accelerator: str | Accelerator, writer: EmbeddingWriter) -> Trainer:
|
|
217
|
+
logger.debug(f"Getting trainer with accelerator '{accelerator}'.")
|
|
198
218
|
return Trainer(
|
|
199
219
|
accelerator=accelerator,
|
|
200
220
|
devices=1,
|