lightly-studio 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lightly-studio might be problematic. Click here for more details.
- lightly_studio/__init__.py +11 -0
- lightly_studio/api/__init__.py +0 -0
- lightly_studio/api/app.py +110 -0
- lightly_studio/api/cache.py +77 -0
- lightly_studio/api/db.py +133 -0
- lightly_studio/api/db_tables.py +32 -0
- lightly_studio/api/features.py +7 -0
- lightly_studio/api/routes/api/annotation.py +233 -0
- lightly_studio/api/routes/api/annotation_label.py +90 -0
- lightly_studio/api/routes/api/annotation_task.py +38 -0
- lightly_studio/api/routes/api/classifier.py +387 -0
- lightly_studio/api/routes/api/dataset.py +182 -0
- lightly_studio/api/routes/api/dataset_tag.py +257 -0
- lightly_studio/api/routes/api/exceptions.py +96 -0
- lightly_studio/api/routes/api/features.py +17 -0
- lightly_studio/api/routes/api/metadata.py +37 -0
- lightly_studio/api/routes/api/metrics.py +80 -0
- lightly_studio/api/routes/api/sample.py +196 -0
- lightly_studio/api/routes/api/settings.py +45 -0
- lightly_studio/api/routes/api/status.py +19 -0
- lightly_studio/api/routes/api/text_embedding.py +48 -0
- lightly_studio/api/routes/api/validators.py +17 -0
- lightly_studio/api/routes/healthz.py +13 -0
- lightly_studio/api/routes/images.py +104 -0
- lightly_studio/api/routes/webapp.py +51 -0
- lightly_studio/api/server.py +82 -0
- lightly_studio/core/__init__.py +0 -0
- lightly_studio/core/dataset.py +523 -0
- lightly_studio/core/sample.py +77 -0
- lightly_studio/core/start_gui.py +15 -0
- lightly_studio/dataset/__init__.py +0 -0
- lightly_studio/dataset/edge_embedding_generator.py +144 -0
- lightly_studio/dataset/embedding_generator.py +91 -0
- lightly_studio/dataset/embedding_manager.py +163 -0
- lightly_studio/dataset/env.py +16 -0
- lightly_studio/dataset/file_utils.py +35 -0
- lightly_studio/dataset/loader.py +622 -0
- lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
- lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
- lightly_studio/examples/example.py +23 -0
- lightly_studio/examples/example_metadata.py +338 -0
- lightly_studio/examples/example_selection.py +39 -0
- lightly_studio/examples/example_split_work.py +67 -0
- lightly_studio/examples/example_v2.py +21 -0
- lightly_studio/export_schema.py +18 -0
- lightly_studio/few_shot_classifier/__init__.py +0 -0
- lightly_studio/few_shot_classifier/classifier.py +80 -0
- lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
- lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
- lightly_studio/metadata/complex_metadata.py +47 -0
- lightly_studio/metadata/gps_coordinate.py +41 -0
- lightly_studio/metadata/metadata_protocol.py +17 -0
- lightly_studio/metrics/__init__.py +0 -0
- lightly_studio/metrics/detection/__init__.py +0 -0
- lightly_studio/metrics/detection/map.py +268 -0
- lightly_studio/models/__init__.py +1 -0
- lightly_studio/models/annotation/__init__.py +0 -0
- lightly_studio/models/annotation/annotation_base.py +171 -0
- lightly_studio/models/annotation/instance_segmentation.py +56 -0
- lightly_studio/models/annotation/links.py +17 -0
- lightly_studio/models/annotation/object_detection.py +47 -0
- lightly_studio/models/annotation/semantic_segmentation.py +44 -0
- lightly_studio/models/annotation_label.py +47 -0
- lightly_studio/models/annotation_task.py +28 -0
- lightly_studio/models/classifier.py +20 -0
- lightly_studio/models/dataset.py +84 -0
- lightly_studio/models/embedding_model.py +30 -0
- lightly_studio/models/metadata.py +208 -0
- lightly_studio/models/sample.py +180 -0
- lightly_studio/models/sample_embedding.py +37 -0
- lightly_studio/models/settings.py +60 -0
- lightly_studio/models/tag.py +96 -0
- lightly_studio/py.typed +0 -0
- lightly_studio/resolvers/__init__.py +7 -0
- lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
- lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
- lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
- lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
- lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
- lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
- lightly_studio/resolvers/annotation_resolver/create.py +19 -0
- lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
- lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
- lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
- lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
- lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
- lightly_studio/resolvers/annotation_task_resolver.py +31 -0
- lightly_studio/resolvers/annotations/__init__.py +1 -0
- lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
- lightly_studio/resolvers/dataset_resolver.py +278 -0
- lightly_studio/resolvers/embedding_model_resolver.py +100 -0
- lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
- lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
- lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
- lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
- lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
- lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
- lightly_studio/resolvers/sample_resolver.py +249 -0
- lightly_studio/resolvers/samples_filter.py +81 -0
- lightly_studio/resolvers/settings_resolver.py +58 -0
- lightly_studio/resolvers/tag_resolver.py +276 -0
- lightly_studio/selection/README.md +6 -0
- lightly_studio/selection/mundig.py +105 -0
- lightly_studio/selection/select.py +96 -0
- lightly_studio/selection/select_via_db.py +93 -0
- lightly_studio/selection/selection_config.py +31 -0
- lightly_studio/services/annotations_service/__init__.py +21 -0
- lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
- lightly_studio/services/annotations_service/update_annotation.py +65 -0
- lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
- lightly_studio/services/annotations_service/update_annotations.py +29 -0
- lightly_studio/setup_logging.py +19 -0
- lightly_studio/type_definitions.py +19 -0
- lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
- lightly_studio/vendor/LICENSE +31 -0
- lightly_studio/vendor/LICENSE_weights_data +50 -0
- lightly_studio/vendor/README.md +5 -0
- lightly_studio/vendor/__init__.py +1 -0
- lightly_studio/vendor/mobileclip/__init__.py +96 -0
- lightly_studio/vendor/mobileclip/clip.py +77 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
- lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
- lightly_studio/vendor/mobileclip/logger.py +154 -0
- lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
- lightly_studio/vendor/mobileclip/models/mci.py +933 -0
- lightly_studio/vendor/mobileclip/models/vit.py +433 -0
- lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
- lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
- lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
- lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
- lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
- lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
- lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
- lightly_studio-0.3.1.dist-info/METADATA +520 -0
- lightly_studio-0.3.1.dist-info/RECORD +219 -0
- lightly_studio-0.3.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""API endpoints for annotation tasks."""
|
|
2
|
+
|
|
3
|
+
from typing import List
|
|
4
|
+
from uuid import UUID
|
|
5
|
+
|
|
6
|
+
from fastapi import APIRouter, Depends, HTTPException, status
|
|
7
|
+
from sqlmodel import Session
|
|
8
|
+
|
|
9
|
+
from lightly_studio.api.db import get_session
|
|
10
|
+
from lightly_studio.models.annotation_task import AnnotationTaskTable
|
|
11
|
+
from lightly_studio.resolvers import annotation_task_resolver
|
|
12
|
+
|
|
13
|
+
# Every annotation-task endpoint below is mounted under /annotationtasks.
router = APIRouter(
    prefix="/annotationtasks",
    tags=["annotationtasks"],
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@router.get("/", response_model=List[AnnotationTaskTable])
def get_annotation_tasks(
    session: Session = Depends(get_session),  # noqa: B008
) -> List[AnnotationTaskTable]:
    """List every annotation task known to the resolver.

    Args:
        session: Database session injected by FastAPI.

    Returns:
        The annotation task rows returned by ``annotation_task_resolver.get_all``.
    """
    all_tasks = annotation_task_resolver.get_all(session=session)
    return all_tasks
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@router.get("/{annotation_task_id}", response_model=AnnotationTaskTable)
def get_annotation_task(
    annotation_task_id: UUID,
    session: Session = Depends(get_session),  # noqa: B008
) -> AnnotationTaskTable:
    """Fetch a single annotation task by its identifier.

    Args:
        annotation_task_id: ID of the task to look up.
        session: Database session injected by FastAPI.

    Returns:
        The matching annotation task row.

    Raises:
        HTTPException: 404 when the resolver finds no task with that ID.
    """
    found_task = annotation_task_resolver.get_by_id(
        session=session,
        annotation_task_id=annotation_task_id,
    )
    if found_task is not None:
        return found_task
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Annotation task with ID {annotation_task_id} not found",
    )
|
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
"""This module contains the API routes for managing classifiers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import io
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from uuid import UUID
|
|
8
|
+
|
|
9
|
+
from fastapi import APIRouter, Depends, UploadFile
|
|
10
|
+
from fastapi.responses import StreamingResponse
|
|
11
|
+
from pydantic import BaseModel
|
|
12
|
+
from sqlmodel import Session
|
|
13
|
+
from typing_extensions import Annotated
|
|
14
|
+
|
|
15
|
+
from lightly_studio.api.db import get_session
|
|
16
|
+
from lightly_studio.few_shot_classifier.classifier import (
|
|
17
|
+
ExportType,
|
|
18
|
+
)
|
|
19
|
+
from lightly_studio.few_shot_classifier.classifier_manager import (
|
|
20
|
+
ClassifierManagerProvider,
|
|
21
|
+
)
|
|
22
|
+
from lightly_studio.models.classifier import EmbeddingClassifier
|
|
23
|
+
|
|
24
|
+
# Router collecting all classifier endpoints; routes carry full paths.
classifier_router = APIRouter()

# Dependency alias: annotating a parameter as `SessionDep` makes FastAPI
# inject a database session obtained from `get_session`.
SessionDep = Annotated[
    Session,
    Depends(get_session),
]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class GetNegativeSamplesRequest(BaseModel):
    """Payload asking for negative examples to train a classifier against."""

    # Samples the user picked as positive examples.
    positive_sample_ids: list[UUID]
    # Dataset the negative samples are requested for.
    dataset_id: UUID
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class GetNegativeSamplesResponse(BaseModel):
    """Result holding the negative samples proposed for classifier training."""

    # IDs of the proposed negative samples.
    negative_sample_ids: list[UUID]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@classifier_router.post("/classifiers/get_negative_samples")
def get_negative_samples(
    request: GetNegativeSamplesRequest, session: SessionDep
) -> GetNegativeSamplesResponse:
    """Ask the classifier manager for negatives matching the given positives.

    Args:
        request: Positive sample IDs and the dataset to draw from.
        session: Database session.

    Returns:
        Response listing the IDs of the samples the manager proposed.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()
    proposed = manager.provide_negative_samples(
        session=session,
        dataset_id=request.dataset_id,
        selected_samples=request.positive_sample_ids,
    )
    # The manager returns sample objects; the API only exposes their IDs.
    return GetNegativeSamplesResponse(
        negative_sample_ids=[candidate.sample_id for candidate in proposed]
    )
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class SamplesToRefineResponse(BaseModel):
    """Class-name to sample-ID mapping used for classifier refinement.

    Per the original contract, the first class receives high-confidence
    samples and the second class receives low-confidence samples.
    """

    # Maps each class name to the IDs of the samples assigned to it.
    samples: dict[str, list[UUID]]
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@classifier_router.get("/classifiers/{classifier_id}/samples_to_refine")
def samples_to_refine(
    classifier_id: UUID,
    dataset_id: UUID,
    session: SessionDep,
) -> SamplesToRefineResponse:
    """Collect samples the user should review to refine a classifier.

    Args:
        classifier_id: Classifier being refined.
        dataset_id: Dataset to pick candidate samples from.
        session: Database session.

    Returns:
        Class-name to sample-ID mapping from the classifier manager.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()
    samples_by_class = manager.get_samples_for_fine_tuning(
        session=session,
        classifier_id=classifier_id,
        dataset_id=dataset_id,
    )
    return SamplesToRefineResponse(samples=samples_by_class)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@classifier_router.get("/classifiers/{classifier_id}/sample_history")
def sample_history(
    classifier_id: UUID,
) -> SamplesToRefineResponse:
    """Report the annotated samples a classifier has been trained with.

    Args:
        classifier_id: Classifier whose training history is requested.

    Returns:
        Class-name to sample-ID mapping taken from the manager's annotations.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()
    annotations_by_class = manager.get_annotations(classifier_id=classifier_id)
    return SamplesToRefineResponse(samples=annotations_by_class)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@classifier_router.post(
    "/classifiers/{classifier_id}/commit_temp_classifier",
)
def commit_temp_classifier(
    classifier_id: UUID,
) -> None:
    """Ask the classifier manager to commit a temporary classifier.

    Args:
        classifier_id: Classifier to commit.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()
    manager.commit_temp_classifier(classifier_id=classifier_id)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@classifier_router.delete(
    "/classifiers/{classifier_id}/drop_temp_classifier",
)
def drop_temp_classifier(
    classifier_id: UUID,
) -> None:
    """Ask the classifier manager to discard a temporary classifier.

    Args:
        classifier_id: Classifier to drop.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()
    manager.drop_temp_classifier(classifier_id=classifier_id)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class SaveClassifierRequest(BaseModel):
    """Payload naming a file path to save a classifier to."""

    # Destination path as a string.
    # NOTE(review): the save endpoint in this module streams the classifier
    # back instead of taking this model; confirm whether it is still used.
    file_path: str
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _build_content_disposition(filename: str) -> str:
    """Build an ``attachment`` Content-Disposition value safe for any filename.

    Plain printable-ASCII names (without ``"`` or ``\\``) yield exactly
    ``attachment; filename="<name>"``. Other names get an ASCII fallback,
    with unsafe characters replaced by ``_``, plus an RFC 6266 / RFC 5987
    ``filename*`` parameter carrying the percent-encoded UTF-8 name.
    """
    from urllib.parse import quote

    ascii_fallback = "".join(
        ch if 0x20 <= ord(ch) < 0x7F and ch not in '"\\' else "_" for ch in filename
    )
    if ascii_fallback == filename:
        return f'attachment; filename="{filename}"'
    encoded = quote(filename, safe="")
    return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{encoded}"


@classifier_router.post(
    "/classifiers/{classifier_id}/save_classifier_to_file/{export_type}",
)
def save_classifier_to_file(
    classifier_id: UUID,
    export_type: ExportType,
) -> StreamingResponse:
    """Serialize a classifier and stream it back as a downloadable file.

    Args:
        classifier_id: The ID of the classifier.
        export_type: The export format, forwarded to the classifier manager.

    Returns:
        StreamingResponse with the serialized classifier as an
        ``application/octet-stream`` attachment named ``<classifier_name>.pkl``.
    """
    classifier_manager = ClassifierManagerProvider.get_classifier_manager()
    # Use BytesIO to capture the file content and send it as a response.
    buffer = io.BytesIO()
    classifier_manager.save_classifier_to_buffer(
        classifier_id=classifier_id, buffer=buffer, export_type=export_type
    )
    buffer.seek(0)

    # Get classifier name for the filename.
    classifier = classifier_manager.get_classifier_by_id(classifier_id=classifier_id)
    filename = f"{classifier.classifier_name}.pkl"
    headers = {
        # The name is user-chosen: a raw quote would corrupt the header and a
        # non-latin-1 character cannot be encoded into an HTTP header at all.
        "Content-Disposition": _build_content_disposition(filename),
        "Content-Type": "application/octet-stream",
        "Access-Control-Expose-Headers": "Content-Disposition",
    }

    return StreamingResponse(buffer, headers=headers, media_type="application/octet-stream")
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class LoadClassifierRequest(BaseModel):
    """Payload naming a file path to load a classifier from."""

    # Path to the classifier file, converted to `pathlib.Path` by the endpoint.
    file_path: str
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class LoadClassifierResponse(BaseModel):
    """Result of loading a classifier from a file."""

    # Identifier assigned to the freshly loaded classifier.
    classifier_id: UUID
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
@classifier_router.post(
    "/classifiers/load_classifier_from_file",
)
def load_classifier_from_file(
    request: LoadClassifierRequest,
    session: SessionDep,
) -> LoadClassifierResponse:
    """Load a classifier from a path on the server's filesystem.

    Args:
        request: Holds the path of the file to load.
        session: Database session.

    Returns:
        Response carrying the ID of the loaded classifier.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()
    # NOTE(review): the client chooses an arbitrary server-side path, and the
    # export endpoint produces `.pkl` files. If the manager unpickles this file,
    # only expose the endpoint to trusted users. Confirm the loader's format.
    source_path = Path(request.file_path)
    loaded = manager.load_classifier_from_file(session=session, file_path=source_path)
    return LoadClassifierResponse(classifier_id=loaded.classifier_id)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
@classifier_router.post(
    "/classifiers/load_classifier_from_buffer",
)
def load_classifier_from_buffer(
    file: UploadFile,
    session: SessionDep,
) -> UUID:
    """Load a classifier from an uploaded file.

    Args:
        file: The uploaded classifier file.
        session: Database session.

    Returns:
        The ID of the loaded classifier.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()

    # The whole upload is read into memory before being handed to the manager.
    uploaded_bytes = file.file.read()
    # NOTE(review): exports are `.pkl` files. If the manager unpickles this
    # buffer, a malicious upload can execute arbitrary code. Confirm and
    # restrict to trusted users if so.
    loaded = manager.load_classifier_from_buffer(
        session=session, buffer=io.BytesIO(uploaded_bytes)
    )
    return loaded.classifier_id
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
@classifier_router.post(
    "/classifiers/{classifier_id}/train_classifier",
)
def train_classifier(
    classifier_id: UUID,
    session: SessionDep,
) -> None:
    """Ask the classifier manager to train the given classifier.

    Args:
        classifier_id: Classifier to train.
        session: Database session.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()
    manager.train_classifier(session=session, classifier_id=classifier_id)
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class UpdateAnnotationsRequest(BaseModel):
    """Payload replacing a classifier's annotations."""

    # Class name -> IDs of samples annotated with that class.
    annotations: dict[str, list[UUID]]
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
@classifier_router.post(
    "/classifiers/{classifier_id}/update_annotations",
)
def update_classifiers_annotations(
    classifier_id: UUID,
    request: UpdateAnnotationsRequest,
) -> None:
    """Forward a new set of annotations for a classifier to the manager.

    Args:
        classifier_id: Classifier whose annotations change.
        request: Holds the new class-name to sample-ID annotations.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()
    new_annotations = request.annotations
    manager.update_classifiers_annotations(
        classifier_id=classifier_id,
        new_annotations=new_annotations,
    )
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
class CreateClassifierRequest(BaseModel):
    """Payload describing a classifier to create."""

    # Human-readable classifier name.
    name: str
    # Names of the classes the classifier distinguishes.
    class_list: list[str]
    # Dataset the classifier is created for.
    dataset_id: UUID
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
class CreateClassifierResponse(BaseModel):
    """Result of creating a classifier."""

    # Name reported by the created classifier.
    name: str
    # Classifier ID, serialized as a string.
    classifier_id: str
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
@classifier_router.post("/classifiers/create")
def create_classifier(
    request: CreateClassifierRequest, session: SessionDep
) -> CreateClassifierResponse:
    """Create a classifier through the classifier manager.

    Args:
        request: Name, class list, and dataset for the new classifier.
        session: Database session.

    Returns:
        Response with the new classifier's name and its ID as a string.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()
    created = manager.create_classifier(
        session=session,
        name=request.name,
        class_list=request.class_list,
        dataset_id=request.dataset_id,
    )
    created_name = created.few_shot_classifier.name
    created_id = str(created.classifier_id)
    return CreateClassifierResponse(name=created_name, classifier_id=created_id)
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
class GetAllClassifiersResponse(BaseModel):
    """Envelope listing the classifiers known to the classifier manager."""

    # Classifiers as returned by `ClassifierManager.get_all_classifiers`.
    classifiers: list[EmbeddingClassifier]
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
@classifier_router.get("/classifiers/get_all_classifiers")
def get_all_classifiers() -> GetAllClassifiersResponse:
    """Get all active classifiers.

    Returns:
        Response wrapping the list of ``EmbeddingClassifier`` entries reported
        by the classifier manager.
    """
    classifier_manager = ClassifierManagerProvider.get_classifier_manager()
    classifiers = classifier_manager.get_all_classifiers()
    return GetAllClassifiersResponse(classifiers=classifiers)
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
@classifier_router.post(
    "/classifiers/{classifier_id}/run_on_dataset/{dataset_id}",
)
def run_classifier_route(
    classifier_id: UUID,
    dataset_id: UUID,
    session: SessionDep,
) -> None:
    """Apply a classifier to every sample of a dataset via the manager.

    Args:
        classifier_id: Classifier to run.
        dataset_id: Dataset to run the classifier on.
        session: Database session.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()
    manager.run_classifier(
        session=session,
        classifier_id=classifier_id,
        dataset_id=dataset_id,
    )
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""This module contains the API routes for managing datasets."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
from typing import List
|
|
7
|
+
from uuid import UUID
|
|
8
|
+
|
|
9
|
+
from fastapi import APIRouter, Depends, HTTPException, Path, Query
|
|
10
|
+
from fastapi.responses import PlainTextResponse
|
|
11
|
+
from pydantic import BaseModel
|
|
12
|
+
from sqlmodel import Field, Session
|
|
13
|
+
from typing_extensions import Annotated
|
|
14
|
+
|
|
15
|
+
from lightly_studio.api.db import get_session
|
|
16
|
+
from lightly_studio.api.routes.api.status import HTTP_STATUS_NOT_FOUND
|
|
17
|
+
from lightly_studio.api.routes.api.validators import Paginated
|
|
18
|
+
from lightly_studio.models.dataset import (
|
|
19
|
+
DatasetCreate,
|
|
20
|
+
DatasetTable,
|
|
21
|
+
DatasetView,
|
|
22
|
+
)
|
|
23
|
+
from lightly_studio.resolvers import dataset_resolver
|
|
24
|
+
from lightly_studio.resolvers.dataset_resolver import (
|
|
25
|
+
ExportFilter,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
# Annotated dependency that injects the database session produced by
# ``get_session`` into route handlers.
SessionDep = Annotated[Session, Depends(get_session)]

# Router on which the dataset endpoints of this module are registered.
dataset_router = APIRouter()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_and_validate_dataset_id(
    session: SessionDep,
    dataset_id: UUID,
) -> DatasetTable:
    """Resolve ``dataset_id`` to its dataset or abort the request as not found.

    Routes use this through ``Depends(get_and_validate_dataset_id)`` so their
    handlers receive an existing ``DatasetTable`` instead of a raw ID.

    Args:
        session: Database session.
        dataset_id: ID of the dataset, taken from the request path.

    Returns:
        The dataset with the given ID.

    Raises:
        HTTPException: With status ``HTTP_STATUS_NOT_FOUND`` when the resolver
            finds no dataset for ``dataset_id``.
    """
    dataset = dataset_resolver.get_by_id(session=session, dataset_id=dataset_id)
    # Check for a missing row explicitly rather than via truthiness, so a found
    # row is never rejected because of how the model object evaluates as bool.
    if dataset is None:
        raise HTTPException(
            status_code=HTTP_STATUS_NOT_FOUND,
            detail=f"Dataset with ID {dataset_id} not found.",
        )
    return dataset
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataset_router.post(
    "/datasets",
    status_code=201,
    response_model=DatasetView,
)
def create_dataset(
    dataset_input: DatasetCreate,
    session: SessionDep,
) -> DatasetTable:
    """Store a new dataset and return it.

    Args:
        dataset_input: Fields of the dataset to create.
        session: Database session.

    Returns:
        The created dataset as returned by the resolver; FastAPI serializes it
        through ``DatasetView``.
    """
    created = dataset_resolver.create(session=session, dataset=dataset_input)
    return created
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataset_router.get("/datasets", response_model=List[DatasetView])
def read_datasets(
    session: SessionDep,
    paginated: Annotated[Paginated, Query()],
) -> list[DatasetTable]:
    """List datasets one page at a time.

    Args:
        session: Database session.
        paginated: Query parameters carrying the page ``offset`` and ``limit``.

    Returns:
        The datasets of the requested page.
    """
    offset, limit = paginated.offset, paginated.limit
    return dataset_resolver.get_all(session=session, offset=offset, limit=limit)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataset_router.get("/datasets/{dataset_id}")
def read_dataset(
    dataset: Annotated[
        DatasetTable,
        Path(title="Dataset Id"),
        Depends(get_and_validate_dataset_id),
    ],
) -> DatasetTable:
    """Return one dataset.

    Looking the dataset up and rejecting unknown IDs is done by the
    ``get_and_validate_dataset_id`` dependency, so the handler only hands
    back the resolved row.
    """
    resolved = dataset
    return resolved
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@dataset_router.put("/datasets/{dataset_id}")
def update_dataset(
    session: SessionDep,
    dataset: Annotated[
        DatasetTable,
        Path(title="Dataset Id"),
        Depends(get_and_validate_dataset_id),
    ],
    dataset_input: DatasetCreate,
) -> DatasetTable:
    """Update an existing dataset with the submitted fields.

    Args:
        session: Database session.
        dataset: The dataset to update, resolved from the path ID.
        dataset_input: New field values for the dataset.

    Returns:
        The dataset as returned by the resolver after the update.
    """
    target_id = dataset.dataset_id
    return dataset_resolver.update(
        session=session, dataset_id=target_id, dataset_data=dataset_input
    )
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataset_router.delete("/datasets/{dataset_id}")
def delete_dataset(
    session: SessionDep,
    dataset: Annotated[
        DatasetTable,
        Path(title="Dataset Id"),
        Depends(get_and_validate_dataset_id),
    ],
) -> dict[str, str]:
    """Delete a dataset.

    Args:
        session: Database session.
        dataset: The dataset to delete, resolved from the path ID.

    Returns:
        A small status payload confirming the deletion.
    """
    target_id = dataset.dataset_id
    dataset_resolver.delete(session=session, dataset_id=target_id)
    outcome: dict[str, str] = {"status": "deleted"}
    return outcome
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class ExportBody(BaseModel):
    """Request body restricting which samples an export covers.

    Both filters are optional. Each one selects samples either by sample ID
    or by tag ID: ``include`` keeps the matches, ``exclude`` drops them.
    """

    include: ExportFilter | None = Field(
        default=None,
        description="include filter for sample_ids or tag_ids",
    )
    exclude: ExportFilter | None = Field(
        default=None,
        description="exclude filter for sample_ids or tag_ids",
    )
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _export_filename(dataset_name: str, exported_at: datetime) -> str:
    """Build a download filename that is safe to send in a response header.

    Starlette encodes header values as latin-1, and an unquoted ``filename=``
    parameter may not contain spaces. Arbitrary dataset names and the default
    ``str(datetime)`` form (``2024-01-01 12:00:00.123456+00:00``) therefore
    cannot be used verbatim. Every character of the name outside
    ``[A-Za-z0-9._-]`` is replaced by ``_``. The timestamp is rendered in UTC
    without spaces or colons (colons are also illegal in Windows filenames).

    Args:
        dataset_name: Name of the exported dataset.
        exported_at: Time of the export; converted to UTC if it is aware.

    Returns:
        A filename such as ``my_dataset_exported_2024-01-01T12-00-00Z.txt``.
    """
    safe_name = "".join(
        char if char.isascii() and (char.isalnum() or char in "._-") else "_"
        for char in dataset_name
    )
    timestamp = exported_at.astimezone(timezone.utc).strftime("%Y-%m-%dT%H-%M-%SZ")
    return f"{safe_name}_exported_{timestamp}.txt"


# This endpoint should be a GET, however due to the potential huge size
# of sample_ids, it is a POST request to avoid URL length limitations.
# A body with a GET request is supported by FastAPI however it has undefined
# behavior: https://fastapi.tiangolo.com/tutorial/body/
@dataset_router.post(
    "/datasets/{dataset_id}/export",
)
def export_dataset_to_absolute_paths(
    session: SessionDep,
    dataset: Annotated[
        DatasetTable,
        Path(title="Dataset Id"),
        Depends(get_and_validate_dataset_id),
    ],
    body: ExportBody,
) -> PlainTextResponse:
    """Export the selected samples of a dataset as a downloadable text file.

    Args:
        session: Database session.
        dataset: The dataset to export, resolved from the path ID.
        body: Optional include/exclude filters on sample or tag IDs.

    Returns:
        A plain-text attachment containing the entries returned by
        ``dataset_resolver.export``, one per line. Per the endpoint name these
        are presumably absolute sample paths; verify against the resolver.
    """
    exported = dataset_resolver.export(
        session=session,
        dataset_id=dataset.dataset_id,
        include=body.include,
        exclude=body.exclude,
    )

    response = PlainTextResponse("\n".join(exported))

    filename = _export_filename(dataset.name, datetime.now(timezone.utc))
    # Let cross-origin browser clients read the header that carries the
    # suggested filename.
    response.headers["Access-Control-Expose-Headers"] = "Content-Disposition"
    # "attachment" makes browsers download the file instead of displaying it.
    response.headers["Content-Disposition"] = f"attachment; filename={filename}"

    return response
|
|
157
|
+
|
|
158
|
+
|
|
159
|
# Endpoint reporting how many samples an export with the given filters covers.
# Like the export endpoint, it is a POST so that large include/exclude ID
# lists travel in the request body instead of the URL.
@dataset_router.post(
    "/datasets/{dataset_id}/export/stats",
)
def export_dataset_stats(
    session: SessionDep,
    dataset: Annotated[
        DatasetTable,
        Path(title="Dataset Id"),
        Depends(get_and_validate_dataset_id),
    ],
    body: ExportBody,
) -> int:
    """Count the dataset samples that match the export filters.

    Args:
        session: Database session.
        dataset: The dataset to inspect, resolved from the path ID.
        body: Optional include/exclude filters on sample or tag IDs.

    Returns:
        The count computed by ``dataset_resolver.get_filtered_samples_count``
        for these filters.
    """
    return dataset_resolver.get_filtered_samples_count(
        session=session,
        dataset_id=dataset.dataset_id,
        include=body.include,
        exclude=body.exclude,
    )
|