lightly-studio 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lightly-studio might be problematic. Click here for more details.
- lightly_studio/__init__.py +1 -1
- lightly_studio/api/app.py +8 -4
- lightly_studio/api/db_tables.py +0 -3
- lightly_studio/api/routes/api/annotation.py +26 -0
- lightly_studio/api/routes/api/annotations/__init__.py +7 -0
- lightly_studio/api/routes/api/annotations/create_annotation.py +52 -0
- lightly_studio/api/routes/api/caption.py +30 -0
- lightly_studio/api/routes/api/dataset.py +3 -5
- lightly_studio/api/routes/api/embeddings2d.py +136 -0
- lightly_studio/api/routes/api/export.py +73 -0
- lightly_studio/api/routes/api/metadata.py +57 -1
- lightly_studio/api/routes/api/selection.py +87 -0
- lightly_studio/core/add_samples.py +138 -9
- lightly_studio/core/dataset.py +174 -63
- lightly_studio/core/dataset_query/dataset_query.py +5 -0
- lightly_studio/dataset/env.py +4 -0
- lightly_studio/dataset/file_utils.py +13 -2
- lightly_studio/dataset/loader.py +2 -62
- lightly_studio/dataset/mobileclip_embedding_generator.py +3 -2
- lightly_studio/db_manager.py +10 -4
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.B3oFNb6O.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/2.CkOblLn7.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/Samples.CIbricz7.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.7Ma7YdVg.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/{useFeatureFlags.CV-KWLNP.css → _layout.CefECEWA.css} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/transform.2jKMtOWG.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/-DXuGN29.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{Ccq4ZD0B.js → B7302SU7.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BeWf8-vJ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bqz7dyEC.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C1FmrZbK.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{DRZO-E-T.js → CSCQddQS.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CZGpyrcA.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CfQ4mGwl.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CiaNZCBa.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cqo0Vpvt.js +417 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cy4fgWTG.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5w4xp5l.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DD63uD-T.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DQ8aZ1o-.js +3 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{Df3aMO5B.js → DSxvnAMh.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D_JuJOO3.js +20 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D_ynJAfY.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Dafy4oEQ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{BqBqV92V.js → Dj4O-5se.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DmjAI-UV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Dug7Bq1S.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Dv5BSBQG.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DzBTnFhV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DzX_yyqb.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Frwd2CjB.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H4l0JFh9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H60ATh8g.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/qIv1kPyv.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/sLqs1uaK.js +20 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/u-it74zV.js +96 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.BPc0HQPq.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.SNvc2nrm.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.5jT7P06o.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.Cdy-7S5q.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.C_uoESTX.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.DcO8wIAc.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.BIldfkxL.js +1012 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{3.w9g4AcAx.js → 3.BC9z_TWM.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{4.BBI8KwnD.js → 4.D8X_Ch5n.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.CAXhxJu6.js +39 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{6.CrbkRPam.js → 6.DRA5Ru_2.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.WVBsruHQ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.BuKUrCEN.js +20 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CUIn1yCR.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/clustering.worker-DKqeLtG0.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/search.worker-vNSty3B0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -1
- lightly_studio/dist_lightly_studio_view_app/index.html +15 -14
- lightly_studio/examples/example.py +4 -0
- lightly_studio/examples/example_coco.py +4 -0
- lightly_studio/examples/example_coco_caption.py +24 -0
- lightly_studio/examples/example_metadata.py +4 -1
- lightly_studio/examples/example_selection.py +4 -0
- lightly_studio/examples/example_split_work.py +4 -0
- lightly_studio/examples/example_yolo.py +4 -0
- lightly_studio/export/export_dataset.py +73 -0
- lightly_studio/export/lightly_studio_label_input.py +120 -0
- lightly_studio/few_shot_classifier/classifier_manager.py +5 -26
- lightly_studio/metadata/compute_typicality.py +67 -0
- lightly_studio/models/annotation/annotation_base.py +11 -12
- lightly_studio/models/caption.py +73 -0
- lightly_studio/models/dataset.py +1 -2
- lightly_studio/models/metadata.py +1 -1
- lightly_studio/models/sample.py +2 -2
- lightly_studio/resolvers/annotation_label_resolver/__init__.py +2 -1
- lightly_studio/resolvers/annotation_label_resolver/get_all.py +15 -0
- lightly_studio/resolvers/annotation_resolver/__init__.py +2 -3
- lightly_studio/resolvers/annotation_resolver/create_many.py +3 -3
- lightly_studio/resolvers/annotation_resolver/delete_annotation.py +1 -1
- lightly_studio/resolvers/annotation_resolver/delete_annotations.py +7 -3
- lightly_studio/resolvers/annotation_resolver/get_by_id.py +19 -1
- lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +0 -1
- lightly_studio/resolvers/annotations/annotations_filter.py +1 -11
- lightly_studio/resolvers/caption_resolver.py +80 -0
- lightly_studio/resolvers/dataset_resolver.py +4 -7
- lightly_studio/resolvers/metadata_resolver/__init__.py +2 -2
- lightly_studio/resolvers/metadata_resolver/sample/__init__.py +3 -3
- lightly_studio/resolvers/metadata_resolver/sample/bulk_update_metadata.py +46 -0
- lightly_studio/resolvers/samples_filter.py +18 -10
- lightly_studio/selection/mundig.py +7 -10
- lightly_studio/selection/selection_config.py +4 -1
- lightly_studio/services/annotations_service/__init__.py +8 -0
- lightly_studio/services/annotations_service/create_annotation.py +63 -0
- lightly_studio/services/annotations_service/delete_annotation.py +22 -0
- lightly_studio/type_definitions.py +2 -0
- {lightly_studio-0.3.2.dist-info → lightly_studio-0.3.4.dist-info}/METADATA +231 -41
- {lightly_studio-0.3.2.dist-info → lightly_studio-0.3.4.dist-info}/RECORD +114 -104
- lightly_studio/api/routes/api/annotation_task.py +0 -37
- lightly_studio/api/routes/api/metrics.py +0 -76
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.BBm0IWdq.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.BNTuXSAe.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/2O287xak.js +0 -3
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/7YNGEs1C.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BBoGk9hq.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BRnH9v23.js +0 -92
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bg1Y5eUZ.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C0JiMuYn.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C98Hk3r5.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CG0dMCJi.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cpy-nab_.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Crk-jcvV.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cs31G8Qn.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CsKrY2zA.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cur71c3O.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CzgC3GFB.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D8GZDMNN.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DFRh-Spp.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DcGCxgpH.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DkR_EZ_B.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DqUGznj_.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KpAtIldw.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/M1Q1F7bw.js +0 -4
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/OH7-C_mc.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/gLNdjSzu.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/i0ZZ4z06.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.BI-EA5gL.js +0 -2
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.CcsRl3cZ.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.BbO4Zc3r.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1._I9GR805.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.J2RBFrSr.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.Cmqj25a-.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.C45iKJHA.js +0 -6
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.huHuxdiF.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.FomEdhD6.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cb_ADSLk.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CajIG5ce.js +0 -1
- lightly_studio/metrics/__init__.py +0 -0
- lightly_studio/metrics/detection/__init__.py +0 -0
- lightly_studio/metrics/detection/map.py +0 -268
- lightly_studio/models/annotation_task.py +0 -28
- lightly_studio/resolvers/annotation_resolver/create.py +0 -19
- lightly_studio/resolvers/annotation_task_resolver.py +0 -31
- lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +0 -48
- {lightly_studio-0.3.2.dist-info → lightly_studio-0.3.4.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Computes typicality from embeddings."""
|
|
2
|
+
|
|
3
|
+
from uuid import UUID
|
|
4
|
+
|
|
5
|
+
from lightly_mundig import Typicality # type: ignore[import-untyped]
|
|
6
|
+
from sqlmodel import Session
|
|
7
|
+
|
|
8
|
+
from lightly_studio.dataset.env import LIGHTLY_STUDIO_LICENSE_KEY
|
|
9
|
+
from lightly_studio.resolvers import (
|
|
10
|
+
metadata_resolver,
|
|
11
|
+
sample_embedding_resolver,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# Neighbourhood size K passed to ``Typicality.calculate_typicality`` when
# scoring each sample against its nearest neighbours in embedding space.
DEFAULT_NUM_NEAREST_NEIGHBORS: int = 20
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def compute_typicality_metadata(
    session: Session,
    dataset_id: UUID,
    embedding_model_id: UUID,
    metadata_name: str = "typicality",
) -> None:
    """Computes typicality for each sample in the dataset from embeddings.

    Typicality is a measure of how representative a sample is of the dataset.
    It is calculated for each sample from its K-nearest neighbors in the
    embedding space, with K = ``DEFAULT_NUM_NEAREST_NEIGHBORS``.

    The computed typicality values are stored as metadata for each sample.
    If the dataset has no embeddings for the given model, nothing is computed
    and no metadata is written.

    Args:
        session:
            The database session.
        dataset_id:
            The ID of the dataset for which to compute the typicality.
        embedding_model_id:
            The ID of the embedding model to use for the computation.
        metadata_name:
            The name of the metadata field to store the typicality values in.
            Defaults to "typicality".

    Raises:
        ValueError: If ``LIGHTLY_STUDIO_LICENSE_KEY`` is unset or empty.
        RuntimeError: If the typicality library returns a different number of
            values than there are samples.
    """
    license_key = LIGHTLY_STUDIO_LICENSE_KEY
    # An empty key is as unusable as a missing one, so reject both up front
    # with the same actionable message.
    if not license_key:
        raise ValueError(
            "LIGHTLY_STUDIO_LICENSE_KEY environment variable is not set. "
            "Please set it to your LightlyStudio license key."
        )

    samples = sample_embedding_resolver.get_all_by_dataset_id(
        session=session, dataset_id=dataset_id, embedding_model_id=embedding_model_id
    )
    if not samples:
        # Nothing to score and nothing to store.
        return

    embeddings = [sample.embedding for sample in samples]
    typicality = Typicality(embeddings=embeddings, token=license_key)
    typicality_values = typicality.calculate_typicality(
        num_nearest_neighbors=DEFAULT_NUM_NEAREST_NEIGHBORS
    )
    # Values are paired with samples positionally below, so a length mismatch
    # would silently attach scores to the wrong samples. Use an explicit raise
    # rather than `assert`, which is stripped under `python -O`.
    if len(samples) != len(typicality_values):
        raise RuntimeError(
            "The number of samples and computed typicality values must match"
        )

    metadata = [
        (sample.sample_id, {metadata_name: value})
        for sample, value in zip(samples, typicality_values)
    ]

    metadata_resolver.bulk_update_metadata(session, metadata)
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""This module defines the base annotation model."""
|
|
2
2
|
|
|
3
3
|
from datetime import datetime, timezone
|
|
4
|
+
from enum import Enum
|
|
4
5
|
from typing import TYPE_CHECKING, List, Optional
|
|
5
6
|
from uuid import UUID, uuid4
|
|
6
7
|
|
|
@@ -22,10 +23,6 @@ from lightly_studio.models.annotation.semantic_segmentation import (
|
|
|
22
23
|
SemanticSegmentationAnnotationTable,
|
|
23
24
|
SemanticSegmentationAnnotationView,
|
|
24
25
|
)
|
|
25
|
-
from lightly_studio.models.annotation_task import (
|
|
26
|
-
AnnotationTaskTable,
|
|
27
|
-
AnnotationType,
|
|
28
|
-
)
|
|
29
26
|
|
|
30
27
|
if TYPE_CHECKING:
|
|
31
28
|
from lightly_studio.models.annotation_label import (
|
|
@@ -41,6 +38,15 @@ else:
|
|
|
41
38
|
AnnotationLabelTable = object
|
|
42
39
|
|
|
43
40
|
|
|
41
|
+
class AnnotationType(str, Enum):
    """Kind of annotation recorded in ``AnnotationBaseTable.annotation_type``.

    Members also subclass ``str``, so each member compares equal to its plain
    string value (``AnnotationType.CLASSIFICATION == "classification"``) and
    can be looked up from that value (``AnnotationType("classification")``).
    """

    # Member order is part of the enum's iteration order; keep it stable.
    CLASSIFICATION = "classification"
    SEMANTIC_SEGMENTATION = "semantic_segmentation"
    INSTANCE_SEGMENTATION = "instance_segmentation"
    OBJECT_DETECTION = "object_detection"
|
|
48
|
+
|
|
49
|
+
|
|
44
50
|
class AnnotationBaseTable(SQLModel, table=True):
|
|
45
51
|
"""Base class for all annotation models."""
|
|
46
52
|
|
|
@@ -51,9 +57,7 @@ class AnnotationBaseTable(SQLModel, table=True):
|
|
|
51
57
|
annotation_id: UUID = Field(default_factory=uuid4, primary_key=True)
|
|
52
58
|
annotation_type: AnnotationType
|
|
53
59
|
annotation_label_id: UUID = Field(foreign_key="annotation_labels.annotation_label_id")
|
|
54
|
-
|
|
55
|
-
foreign_key="annotation_tasks.annotation_task_id",
|
|
56
|
-
)
|
|
60
|
+
|
|
57
61
|
confidence: Optional[float] = None
|
|
58
62
|
dataset_id: UUID = Field(foreign_key="datasets.dataset_id")
|
|
59
63
|
sample_id: UUID = Field(foreign_key="samples.sample_id")
|
|
@@ -61,9 +65,6 @@ class AnnotationBaseTable(SQLModel, table=True):
|
|
|
61
65
|
annotation_label: Mapped["AnnotationLabelTable"] = Relationship(
|
|
62
66
|
sa_relationship_kwargs={"lazy": "select"},
|
|
63
67
|
)
|
|
64
|
-
annotation_task: Mapped["AnnotationTaskTable"] = Relationship(
|
|
65
|
-
sa_relationship_kwargs={"lazy": "select"},
|
|
66
|
-
)
|
|
67
68
|
sample: Mapped[Optional["SampleTable"]] = Relationship(
|
|
68
69
|
sa_relationship_kwargs={"lazy": "select"},
|
|
69
70
|
)
|
|
@@ -101,7 +102,6 @@ class AnnotationCreate(SQLModel):
|
|
|
101
102
|
""" Required properties for all annotations. """
|
|
102
103
|
annotation_label_id: UUID
|
|
103
104
|
annotation_type: AnnotationType
|
|
104
|
-
annotation_task_id: UUID
|
|
105
105
|
confidence: Optional[float] = None
|
|
106
106
|
dataset_id: UUID
|
|
107
107
|
sample_id: UUID
|
|
@@ -140,7 +140,6 @@ class AnnotationView(SQLModel):
|
|
|
140
140
|
annotation_id: UUID
|
|
141
141
|
annotation_type: AnnotationType
|
|
142
142
|
annotation_label: AnnotationLabel
|
|
143
|
-
annotation_task_id: UUID
|
|
144
143
|
confidence: Optional[float] = None
|
|
145
144
|
|
|
146
145
|
object_detection_details: Optional[ObjectDetectionAnnotationView] = None
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""This module defines the caption model."""
|
|
2
|
+
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
from typing import TYPE_CHECKING, List, Optional
|
|
5
|
+
from uuid import UUID, uuid4
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, ConfigDict
|
|
8
|
+
from pydantic import Field as PydanticField
|
|
9
|
+
from sqlalchemy.orm import Mapped
|
|
10
|
+
from sqlmodel import Field, Relationship, SQLModel
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from lightly_studio.models.sample import SampleTable
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class CaptionTable(SQLModel, table=True):
    """Database row holding a text caption attached to a sample.

    Each caption references its dataset and its sample through foreign keys.
    The ``sample`` relationship is configured with ``lazy="select"``, i.e. it
    is loaded on first attribute access.
    """

    __tablename__ = "caption"

    # Timezone-aware UTC creation time; indexed because listings order by it.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), index=True)

    # Identity and ownership.
    caption_id: UUID = Field(default_factory=uuid4, primary_key=True)
    dataset_id: UUID = Field(foreign_key="datasets.dataset_id")
    sample_id: UUID = Field(foreign_key="samples.sample_id")

    sample: Mapped[Optional["SampleTable"]] = Relationship(
        sa_relationship_kwargs={"lazy": "select"},
    )

    # The caption text itself.
    text: str
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class CaptionCreate(SQLModel):
    """Payload for inserting a caption.

    Carries the caption text plus the IDs of the dataset and sample it
    belongs to; server-side fields such as ``caption_id`` and ``created_at``
    are not part of this input.
    """

    dataset_id: UUID
    sample_id: UUID
    text: str
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class CaptionSampleView(SQLModel):
    """Summary of the captioned sample embedded in caption responses.

    Exposes the sample's file location (absolute path and file name) together
    with its dataset and sample IDs.
    """

    file_path_abs: str
    file_name: str
    dataset_id: UUID
    sample_id: UUID
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class CaptionView(SQLModel):
    """Response model for a single caption, without sample details.

    Identifies the caption and the sample/dataset it belongs to, and carries
    the caption text.
    """

    sample_id: UUID
    dataset_id: UUID
    caption_id: UUID
    text: str
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class CaptionDetailsView(CaptionView):
    """Caption response that additionally embeds the captioned sample.

    Extends ``CaptionView`` with a ``CaptionSampleView`` summary, so clients
    can show a caption next to its image without a second lookup.
    """

    sample: CaptionSampleView
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class CaptionsListView(BaseModel):
    """Response model for a page of captions together with pagination info.

    ``captions`` and ``next_cursor`` carry the aliases ``data`` and
    ``nextCursor``; ``populate_by_name=True`` lets callers construct the model
    with either the field names or the aliases. ``next_cursor`` must always be
    supplied (``...``) but may be ``None``.
    """

    model_config = ConfigDict(populate_by_name=True)

    captions: List[CaptionDetailsView] = PydanticField(..., alias="data")
    total_count: int
    next_cursor: Optional[int] = PydanticField(..., alias="nextCursor")
|
lightly_studio/models/dataset.py
CHANGED
|
@@ -19,8 +19,7 @@ from lightly_studio.resolvers.samples_filter import SampleFilter
|
|
|
19
19
|
class DatasetBase(SQLModel):
|
|
20
20
|
"""Base class for the Dataset model."""
|
|
21
21
|
|
|
22
|
-
name: str
|
|
23
|
-
directory: str
|
|
22
|
+
name: str = Field(unique=True, index=True)
|
|
24
23
|
|
|
25
24
|
|
|
26
25
|
class DatasetCreate(DatasetBase):
|
|
@@ -188,7 +188,7 @@ class SampleMetadataTable(MetadataBase, table=True):
|
|
|
188
188
|
"""This class defines the SampleMetadataTable model."""
|
|
189
189
|
|
|
190
190
|
__tablename__ = "metadata"
|
|
191
|
-
sample_id: UUID = Field(foreign_key="samples.sample_id")
|
|
191
|
+
sample_id: UUID = Field(foreign_key="samples.sample_id", unique=True)
|
|
192
192
|
|
|
193
193
|
sample: SampleTable = Relationship(back_populates="metadata_dict")
|
|
194
194
|
|
lightly_studio/models/sample.py
CHANGED
|
@@ -170,8 +170,8 @@ class SampleView(SQLModel):
|
|
|
170
170
|
file_path_abs: str
|
|
171
171
|
sample_id: UUID
|
|
172
172
|
dataset_id: UUID
|
|
173
|
-
annotations: List["AnnotationView"]
|
|
174
|
-
tags: List[SampleViewTag]
|
|
173
|
+
annotations: List["AnnotationView"]
|
|
174
|
+
tags: List[SampleViewTag]
|
|
175
175
|
metadata_dict: Optional["SampleMetadataView"] = None
|
|
176
176
|
width: int
|
|
177
177
|
height: int
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from .create import create
|
|
4
4
|
from .delete import delete
|
|
5
|
-
from .get_all import get_all
|
|
5
|
+
from .get_all import get_all, get_all_sorted_alphabetically
|
|
6
6
|
from .get_by_id import get_by_id
|
|
7
7
|
from .get_by_ids import get_by_ids
|
|
8
8
|
from .get_by_label_name import get_by_label_name
|
|
@@ -13,6 +13,7 @@ __all__ = [
|
|
|
13
13
|
"create",
|
|
14
14
|
"delete",
|
|
15
15
|
"get_all",
|
|
16
|
+
"get_all_sorted_alphabetically",
|
|
16
17
|
"get_by_id",
|
|
17
18
|
"get_by_ids",
|
|
18
19
|
"get_by_label_name",
|
|
@@ -20,3 +20,18 @@ def get_all(session: Session) -> list[AnnotationLabelTable]:
|
|
|
20
20
|
select(AnnotationLabelTable).order_by(col(AnnotationLabelTable.created_at).asc())
|
|
21
21
|
).all()
|
|
22
22
|
return list(labels) if labels else []
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_all_sorted_alphabetically(session: Session) -> list[AnnotationLabelTable]:
    """Return every annotation label ordered by its name.

    Args:
        session (Session): The database session.

    Returns:
        list[AnnotationLabelTable]: All annotation labels, ascending by
        ``annotation_label_name`` (using the database's default collation);
        an empty list when no labels exist.
    """
    by_name = col(AnnotationLabelTable.annotation_label_name).asc()
    statement = select(AnnotationLabelTable).order_by(by_name)
    rows = session.exec(statement).all()
    return [] if not rows else list(rows)
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
from lightly_studio.resolvers.annotation_resolver.count_annotations_by_dataset import (
|
|
4
4
|
count_annotations_by_dataset,
|
|
5
5
|
)
|
|
6
|
-
from lightly_studio.resolvers.annotation_resolver.create import create
|
|
7
6
|
from lightly_studio.resolvers.annotation_resolver.create_many import create_many
|
|
8
7
|
from lightly_studio.resolvers.annotation_resolver.delete_annotation import (
|
|
9
8
|
delete_annotation,
|
|
@@ -12,7 +11,7 @@ from lightly_studio.resolvers.annotation_resolver.delete_annotations import (
|
|
|
12
11
|
delete_annotations,
|
|
13
12
|
)
|
|
14
13
|
from lightly_studio.resolvers.annotation_resolver.get_all import get_all
|
|
15
|
-
from lightly_studio.resolvers.annotation_resolver.get_by_id import get_by_id
|
|
14
|
+
from lightly_studio.resolvers.annotation_resolver.get_by_id import get_by_id, get_by_ids
|
|
16
15
|
from lightly_studio.resolvers.annotation_resolver.update_annotation_label import (
|
|
17
16
|
update_annotation_label,
|
|
18
17
|
)
|
|
@@ -22,12 +21,12 @@ from lightly_studio.resolvers.annotation_resolver.update_bounding_box import (
|
|
|
22
21
|
|
|
23
22
|
__all__ = [
|
|
24
23
|
"count_annotations_by_dataset",
|
|
25
|
-
"create",
|
|
26
24
|
"create_many",
|
|
27
25
|
"delete_annotation",
|
|
28
26
|
"delete_annotations",
|
|
29
27
|
"get_all",
|
|
30
28
|
"get_by_id",
|
|
29
|
+
"get_by_ids",
|
|
31
30
|
"update_annotation_label",
|
|
32
31
|
"update_bounding_box",
|
|
33
32
|
]
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from collections.abc import Sequence
|
|
6
|
+
from uuid import UUID
|
|
6
7
|
|
|
7
8
|
from sqlmodel import Session
|
|
8
9
|
|
|
@@ -24,7 +25,7 @@ from lightly_studio.models.annotation.semantic_segmentation import (
|
|
|
24
25
|
def create_many(
|
|
25
26
|
session: Session,
|
|
26
27
|
annotations: list[AnnotationCreate],
|
|
27
|
-
) -> Sequence[
|
|
28
|
+
) -> Sequence[UUID]:
|
|
28
29
|
"""Create many annotations with object detection details in bulk."""
|
|
29
30
|
# Step 1: Create all base annotations
|
|
30
31
|
base_annotations = []
|
|
@@ -37,7 +38,6 @@ def create_many(
|
|
|
37
38
|
db_base_annotation = AnnotationBaseTable(
|
|
38
39
|
annotation_label_id=annotation_create.annotation_label_id,
|
|
39
40
|
annotation_type=annotation_create.annotation_type,
|
|
40
|
-
annotation_task_id=annotation_create.annotation_task_id,
|
|
41
41
|
confidence=annotation_create.confidence,
|
|
42
42
|
dataset_id=annotation_create.dataset_id,
|
|
43
43
|
sample_id=annotation_create.sample_id,
|
|
@@ -93,4 +93,4 @@ def create_many(
|
|
|
93
93
|
# Commit everything
|
|
94
94
|
session.commit()
|
|
95
95
|
|
|
96
|
-
return base_annotations
|
|
96
|
+
return [annotation.annotation_id for annotation in base_annotations]
|
|
@@ -26,7 +26,7 @@ def delete_annotation(
|
|
|
26
26
|
annotation_id=annotation_id,
|
|
27
27
|
)
|
|
28
28
|
if not annotation:
|
|
29
|
-
|
|
29
|
+
raise ValueError(f"Annotation {annotation_id} not found")
|
|
30
30
|
if annotation.object_detection_details:
|
|
31
31
|
session.delete(annotation.object_detection_details)
|
|
32
32
|
if annotation.instance_segmentation_details:
|
|
@@ -18,14 +18,12 @@ from lightly_studio.resolvers.annotations.annotations_filter import (
|
|
|
18
18
|
|
|
19
19
|
def delete_annotations(
|
|
20
20
|
session: Session,
|
|
21
|
-
annotation_task_ids: list[UUID] | None,
|
|
22
21
|
annotation_label_ids: list[UUID] | None,
|
|
23
22
|
) -> None:
|
|
24
23
|
"""Delete all annotations and their tag links using filters.
|
|
25
24
|
|
|
26
25
|
Args:
|
|
27
26
|
session: Database session.
|
|
28
|
-
annotation_task_ids: List of annotation task IDs to filter by.
|
|
29
27
|
annotation_label_ids: List of annotation label IDs to filter by.
|
|
30
28
|
"""
|
|
31
29
|
# Find annotation_ids to delete
|
|
@@ -33,9 +31,15 @@ def delete_annotations(
|
|
|
33
31
|
session,
|
|
34
32
|
filters=AnnotationsFilter(
|
|
35
33
|
annotation_label_ids=annotation_label_ids,
|
|
36
|
-
annotation_task_ids=annotation_task_ids,
|
|
37
34
|
),
|
|
38
35
|
).annotations
|
|
36
|
+
for annotation in annotations:
|
|
37
|
+
if annotation.object_detection_details:
|
|
38
|
+
session.delete(annotation.object_detection_details)
|
|
39
|
+
if annotation.instance_segmentation_details:
|
|
40
|
+
session.delete(annotation.instance_segmentation_details)
|
|
41
|
+
if annotation.semantic_segmentation_details:
|
|
42
|
+
session.delete(annotation.semantic_segmentation_details)
|
|
39
43
|
annotation_ids = [annotation.annotation_id for annotation in annotations]
|
|
40
44
|
# TODO(Horatiu, 06/2025): Check if there is a way to delete the links
|
|
41
45
|
# automatically using SQLModel/SQLAlchemy.
|
|
@@ -2,9 +2,10 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
from collections.abc import Sequence
|
|
5
6
|
from uuid import UUID
|
|
6
7
|
|
|
7
|
-
from sqlmodel import Session, select
|
|
8
|
+
from sqlmodel import Session, col, select
|
|
8
9
|
|
|
9
10
|
from lightly_studio.models.annotation.annotation_base import (
|
|
10
11
|
AnnotationBaseTable,
|
|
@@ -16,3 +17,20 @@ def get_by_id(session: Session, annotation_id: UUID) -> AnnotationBaseTable | No
|
|
|
16
17
|
return session.exec(
|
|
17
18
|
select(AnnotationBaseTable).where(AnnotationBaseTable.annotation_id == annotation_id)
|
|
18
19
|
).one_or_none()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_by_ids(session: Session, annotation_ids: Sequence[UUID]) -> Sequence[AnnotationBaseTable]:
    """Fetch the annotations whose primary key appears in ``annotation_ids``.

    Args:
        session: The database session to use for the query.
        annotation_ids: The annotation IDs to look up.

    Returns:
        The matching annotations. IDs without a matching row are simply
        absent from the result. No ORDER BY is applied, so the result order
        is not guaranteed to follow ``annotation_ids``.
    """
    id_column = col(AnnotationBaseTable.annotation_id)
    statement = select(AnnotationBaseTable).where(id_column.in_(annotation_ids))
    return session.exec(statement).all()
|
|
@@ -111,7 +111,6 @@ def update_annotation_label(
|
|
|
111
111
|
annotation_id=annotation_copy.annotation_id,
|
|
112
112
|
annotation_label_id=annotation_copy.annotation_label_id,
|
|
113
113
|
annotation_type=annotation_copy.annotation_type,
|
|
114
|
-
annotation_task_id=annotation_copy.annotation_task_id,
|
|
115
114
|
confidence=annotation_copy.confidence,
|
|
116
115
|
created_at=annotation_copy.created_at,
|
|
117
116
|
dataset_id=annotation_copy.dataset_id,
|
|
@@ -7,8 +7,7 @@ from uuid import UUID
|
|
|
7
7
|
from pydantic import BaseModel, Field
|
|
8
8
|
from sqlmodel import col
|
|
9
9
|
|
|
10
|
-
from lightly_studio.models.annotation.annotation_base import AnnotationBaseTable
|
|
11
|
-
from lightly_studio.models.annotation_task import AnnotationType
|
|
10
|
+
from lightly_studio.models.annotation.annotation_base import AnnotationBaseTable, AnnotationType
|
|
12
11
|
from lightly_studio.models.sample import SampleTable
|
|
13
12
|
from lightly_studio.models.tag import TagTable
|
|
14
13
|
from lightly_studio.type_definitions import QueryType
|
|
@@ -30,9 +29,6 @@ class AnnotationsFilter(BaseModel):
|
|
|
30
29
|
default=None,
|
|
31
30
|
description="List of sample tag UUIDs to filter annotations by",
|
|
32
31
|
)
|
|
33
|
-
annotation_task_ids: list[UUID] | None = Field(
|
|
34
|
-
default=None, description="List of annotation task UUIDs"
|
|
35
|
-
)
|
|
36
32
|
|
|
37
33
|
def apply(
|
|
38
34
|
self,
|
|
@@ -51,12 +47,6 @@ class AnnotationsFilter(BaseModel):
|
|
|
51
47
|
if self.dataset_ids:
|
|
52
48
|
query = query.where(col(AnnotationBaseTable.dataset_id).in_(self.dataset_ids))
|
|
53
49
|
|
|
54
|
-
# Filter by annotation task
|
|
55
|
-
if self.annotation_task_ids:
|
|
56
|
-
query = query.where(
|
|
57
|
-
col(AnnotationBaseTable.annotation_task_id).in_(self.annotation_task_ids)
|
|
58
|
-
)
|
|
59
|
-
|
|
60
50
|
# Filter by annotation label
|
|
61
51
|
if self.annotation_label_ids:
|
|
62
52
|
query = query.where(
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""Resolvers for caption."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
from uuid import UUID
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel
|
|
9
|
+
from sqlmodel import Session, col, func, select
|
|
10
|
+
|
|
11
|
+
from lightly_studio.api.routes.api.validators import Paginated
|
|
12
|
+
from lightly_studio.models.caption import CaptionCreate, CaptionTable
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class GetAllCaptionsResult(BaseModel):
|
|
16
|
+
"""Result wrapper for caption listings."""
|
|
17
|
+
|
|
18
|
+
captions: Sequence[CaptionTable]
|
|
19
|
+
total_count: int
|
|
20
|
+
next_cursor: int | None = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def create_many(session: Session, captions: Sequence[CaptionCreate]) -> list[CaptionTable]:
|
|
24
|
+
"""Create many captions in bulk.
|
|
25
|
+
|
|
26
|
+
Args:
|
|
27
|
+
session: Database session
|
|
28
|
+
captions: The captions to create
|
|
29
|
+
|
|
30
|
+
Returns:
|
|
31
|
+
The created captions
|
|
32
|
+
"""
|
|
33
|
+
if not captions:
|
|
34
|
+
return []
|
|
35
|
+
|
|
36
|
+
db_captions = [CaptionTable.model_validate(caption) for caption in captions]
|
|
37
|
+
session.bulk_save_objects(db_captions)
|
|
38
|
+
session.commit()
|
|
39
|
+
return db_captions
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_all(
|
|
43
|
+
session: Session,
|
|
44
|
+
dataset_id: UUID,
|
|
45
|
+
pagination: Paginated | None = None,
|
|
46
|
+
) -> GetAllCaptionsResult:
|
|
47
|
+
"""Get all captions from the database.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
session: Database session
|
|
51
|
+
dataset_id: dataset_id parameter to filter the query
|
|
52
|
+
pagination: Optional pagination parameters
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
List of captions matching the filters, total number of captions, next cursor (pagination)
|
|
56
|
+
"""
|
|
57
|
+
query = select(CaptionTable).order_by(
|
|
58
|
+
col(CaptionTable.created_at).asc(),
|
|
59
|
+
col(CaptionTable.caption_id).asc(),
|
|
60
|
+
)
|
|
61
|
+
count_query = select(func.count()).select_from(CaptionTable)
|
|
62
|
+
|
|
63
|
+
query = query.where(CaptionTable.dataset_id == dataset_id)
|
|
64
|
+
count_query = count_query.where(CaptionTable.dataset_id == dataset_id)
|
|
65
|
+
|
|
66
|
+
if pagination is not None:
|
|
67
|
+
query = query.offset(pagination.offset).limit(pagination.limit)
|
|
68
|
+
|
|
69
|
+
captions = session.exec(query).all()
|
|
70
|
+
total_count = session.exec(count_query).one()
|
|
71
|
+
|
|
72
|
+
next_cursor: int | None = None
|
|
73
|
+
if pagination and pagination.offset + pagination.limit < total_count:
|
|
74
|
+
next_cursor = pagination.offset + pagination.limit
|
|
75
|
+
|
|
76
|
+
return GetAllCaptionsResult(
|
|
77
|
+
captions=captions,
|
|
78
|
+
total_count=total_count,
|
|
79
|
+
next_cursor=next_cursor,
|
|
80
|
+
)
|
|
@@ -41,6 +41,9 @@ class ExportFilter(BaseModel):
|
|
|
41
41
|
|
|
42
42
|
def create(session: Session, dataset: DatasetCreate) -> DatasetTable:
|
|
43
43
|
"""Create a new dataset in the database."""
|
|
44
|
+
existing = get_by_name(session=session, name=dataset.name)
|
|
45
|
+
if existing:
|
|
46
|
+
raise ValueError(f"Dataset with name '{dataset.name}' already exists.")
|
|
44
47
|
db_dataset = DatasetTable.model_validate(dataset)
|
|
45
48
|
session.add(db_dataset)
|
|
46
49
|
session.commit()
|
|
@@ -69,12 +72,7 @@ def get_by_id(session: Session, dataset_id: UUID) -> DatasetTable | None:
|
|
|
69
72
|
|
|
70
73
|
def get_by_name(session: Session, name: str) -> DatasetTable | None:
|
|
71
74
|
"""Retrieve a single dataset by name."""
|
|
72
|
-
|
|
73
|
-
if len(datasets) == 0:
|
|
74
|
-
return None
|
|
75
|
-
if len(datasets) > 1:
|
|
76
|
-
raise ValueError(f"Cannot retrieve a dataset, found multiple with name '{name}'.")
|
|
77
|
-
return datasets[0]
|
|
75
|
+
return session.exec(select(DatasetTable).where(DatasetTable.name == name)).one_or_none()
|
|
78
76
|
|
|
79
77
|
|
|
80
78
|
def update(session: Session, dataset_id: UUID, dataset_data: DatasetCreate) -> DatasetTable:
|
|
@@ -84,7 +82,6 @@ def update(session: Session, dataset_id: UUID, dataset_data: DatasetCreate) -> D
|
|
|
84
82
|
raise ValueError(f"Dataset ID was not found '{dataset_id}'.")
|
|
85
83
|
|
|
86
84
|
dataset.name = dataset_data.name
|
|
87
|
-
dataset.directory = dataset_data.directory
|
|
88
85
|
dataset.updated_at = datetime.now(timezone.utc)
|
|
89
86
|
|
|
90
87
|
session.commit()
|
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
"""Metadata resolver module."""
|
|
2
2
|
|
|
3
3
|
from lightly_studio.resolvers.metadata_resolver.sample import (
|
|
4
|
-
|
|
4
|
+
bulk_update_metadata,
|
|
5
5
|
get_by_sample_id,
|
|
6
6
|
get_value_for_sample,
|
|
7
7
|
set_value_for_sample,
|
|
8
8
|
)
|
|
9
9
|
|
|
10
10
|
__all__ = [
|
|
11
|
-
"
|
|
11
|
+
"bulk_update_metadata",
|
|
12
12
|
"get_by_sample_id",
|
|
13
13
|
"get_value_for_sample",
|
|
14
14
|
"set_value_for_sample",
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Resolvers for metadata operations."""
|
|
2
2
|
|
|
3
|
-
from .
|
|
4
|
-
|
|
3
|
+
from .bulk_update_metadata import (
|
|
4
|
+
bulk_update_metadata,
|
|
5
5
|
)
|
|
6
6
|
from .get_by_sample_id import (
|
|
7
7
|
get_by_sample_id,
|
|
@@ -14,7 +14,7 @@ from .set_value_for_sample import (
|
|
|
14
14
|
)
|
|
15
15
|
|
|
16
16
|
__all__ = [
|
|
17
|
-
"
|
|
17
|
+
"bulk_update_metadata",
|
|
18
18
|
"get_by_sample_id",
|
|
19
19
|
"get_value_for_sample",
|
|
20
20
|
"set_value_for_sample",
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""Resolver for operations for setting metadata."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
from uuid import UUID
|
|
7
|
+
|
|
8
|
+
from sqlmodel import Session, col, select
|
|
9
|
+
|
|
10
|
+
from lightly_studio.models.metadata import SampleMetadataTable
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def bulk_update_metadata(
|
|
14
|
+
session: Session,
|
|
15
|
+
sample_metadata: list[tuple[UUID, dict[str, Any]]],
|
|
16
|
+
) -> None:
|
|
17
|
+
"""Bulk insert or update metadata for multiple samples.
|
|
18
|
+
|
|
19
|
+
If a sample does not have metadata, a new metadata row is created.
|
|
20
|
+
If a sample already has metadata, the new key-value pairs are merged with the existing metadata.
|
|
21
|
+
|
|
22
|
+
Args:
|
|
23
|
+
session: The database session.
|
|
24
|
+
sample_metadata: List of (sample_id, metadata_dict) tuples.
|
|
25
|
+
"""
|
|
26
|
+
# TODO(Mihnea, 10/2025): Consider using SQLAlchemy's bulk operations
|
|
27
|
+
# (Session.bulk_insert/update_mappings) if performance becomes a bottleneck.
|
|
28
|
+
if not sample_metadata:
|
|
29
|
+
return
|
|
30
|
+
|
|
31
|
+
# Get all existing metadata rows for the given sample IDs.
|
|
32
|
+
sample_ids = [s[0] for s in sample_metadata]
|
|
33
|
+
existing_metadata = session.exec(
|
|
34
|
+
select(SampleMetadataTable).where(col(SampleMetadataTable.sample_id).in_(sample_ids))
|
|
35
|
+
).all()
|
|
36
|
+
sample_id_to_existing_metadata = {meta.sample_id: meta for meta in existing_metadata}
|
|
37
|
+
|
|
38
|
+
for sample_id, new_metadata in sample_metadata:
|
|
39
|
+
metadata = sample_id_to_existing_metadata.get(
|
|
40
|
+
sample_id, SampleMetadataTable(sample_id=sample_id)
|
|
41
|
+
)
|
|
42
|
+
for key, value in new_metadata.items():
|
|
43
|
+
metadata.set_value(key, value)
|
|
44
|
+
session.add(metadata)
|
|
45
|
+
|
|
46
|
+
session.commit()
|