lightly-studio 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lightly-studio might be problematic. Click here for more details.
- lightly_studio/__init__.py +4 -4
- lightly_studio/api/app.py +1 -1
- lightly_studio/api/routes/api/annotation.py +6 -16
- lightly_studio/api/routes/api/annotation_label.py +2 -5
- lightly_studio/api/routes/api/annotation_task.py +4 -5
- lightly_studio/api/routes/api/classifier.py +2 -5
- lightly_studio/api/routes/api/dataset.py +2 -3
- lightly_studio/api/routes/api/dataset_tag.py +2 -3
- lightly_studio/api/routes/api/metadata.py +2 -4
- lightly_studio/api/routes/api/metrics.py +2 -6
- lightly_studio/api/routes/api/sample.py +5 -13
- lightly_studio/api/routes/api/settings.py +2 -6
- lightly_studio/api/routes/images.py +6 -6
- lightly_studio/core/add_samples.py +383 -0
- lightly_studio/core/dataset.py +250 -362
- lightly_studio/core/dataset_query/__init__.py +0 -0
- lightly_studio/core/dataset_query/boolean_expression.py +67 -0
- lightly_studio/core/dataset_query/dataset_query.py +211 -0
- lightly_studio/core/dataset_query/field.py +113 -0
- lightly_studio/core/dataset_query/field_expression.py +79 -0
- lightly_studio/core/dataset_query/match_expression.py +23 -0
- lightly_studio/core/dataset_query/order_by.py +79 -0
- lightly_studio/core/dataset_query/sample_field.py +28 -0
- lightly_studio/core/dataset_query/tags_expression.py +46 -0
- lightly_studio/core/sample.py +159 -32
- lightly_studio/core/start_gui.py +35 -0
- lightly_studio/dataset/edge_embedding_generator.py +13 -8
- lightly_studio/dataset/embedding_generator.py +2 -3
- lightly_studio/dataset/embedding_manager.py +74 -6
- lightly_studio/dataset/fsspec_lister.py +275 -0
- lightly_studio/dataset/loader.py +49 -30
- lightly_studio/dataset/mobileclip_embedding_generator.py +6 -4
- lightly_studio/db_manager.py +145 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.BBm0IWdq.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.BNTuXSAe.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/2O287xak.js +3 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{O-EABkf9.js → 7YNGEs1C.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BBoGk9hq.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BRnH9v23.js +92 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bg1Y5eUZ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{DOlTMNyt.js → BqBqV92V.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C0JiMuYn.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{DjfY96ND.js → C98Hk3r5.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{r64xT6ao.js → CG0dMCJi.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{C8I8rFJQ.js → Ccq4ZD0B.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cpy-nab_.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{Bu7uvVrG.js → Crk-jcvV.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cs31G8Qn.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CsKrY2zA.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{x9G_hzyY.js → Cur71c3O.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CzgC3GFB.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D8GZDMNN.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DFRh-Spp.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{BylOuP6i.js → DRZO-E-T.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{l7KrR96u.js → DcGCxgpH.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{Bsi3UGy5.js → Df3aMO5B.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{hQVEETDE.js → DkR_EZ_B.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DqUGznj_.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KpAtIldw.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/M1Q1F7bw.js +4 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{CDnpyLsT.js → OH7-C_mc.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{D6su9Aln.js → gLNdjSzu.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/i0ZZ4z06.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.BI-EA5gL.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.CcsRl3cZ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.BbO4Zc3r.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{1.B4rNYwVp.js → 1._I9GR805.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.J2RBFrSr.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.Cmqj25a-.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.C45iKJHA.js +6 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{3.CWHpKonm.js → 3.w9g4AcAx.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{4.OUWOLQeV.js → 4.BBI8KwnD.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.huHuxdiF.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.CrbkRPam.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.FomEdhD6.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cb_ADSLk.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{9.CPu3CiBc.js → 9.CajIG5ce.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -1
- lightly_studio/dist_lightly_studio_view_app/index.html +14 -14
- lightly_studio/examples/example.py +13 -12
- lightly_studio/examples/example_coco.py +13 -0
- lightly_studio/examples/example_metadata.py +83 -98
- lightly_studio/examples/example_selection.py +7 -19
- lightly_studio/examples/example_split_work.py +12 -36
- lightly_studio/examples/{example_v2.py → example_yolo.py} +3 -4
- lightly_studio/models/annotation/annotation_base.py +7 -8
- lightly_studio/models/annotation/instance_segmentation.py +8 -8
- lightly_studio/models/annotation/object_detection.py +4 -4
- lightly_studio/models/dataset.py +6 -2
- lightly_studio/models/sample.py +10 -3
- lightly_studio/resolvers/dataset_resolver.py +10 -0
- lightly_studio/resolvers/embedding_model_resolver.py +22 -0
- lightly_studio/resolvers/sample_resolver.py +53 -9
- lightly_studio/resolvers/tag_resolver.py +23 -0
- lightly_studio/selection/select.py +55 -46
- lightly_studio/selection/select_via_db.py +23 -19
- lightly_studio/selection/selection_config.py +6 -3
- lightly_studio/services/annotations_service/__init__.py +4 -0
- lightly_studio/services/annotations_service/update_annotation.py +21 -32
- lightly_studio/services/annotations_service/update_annotation_bounding_box.py +36 -0
- lightly_studio-0.3.2.dist-info/METADATA +689 -0
- {lightly_studio-0.3.1.dist-info → lightly_studio-0.3.2.dist-info}/RECORD +104 -91
- lightly_studio/api/db.py +0 -133
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +0 -93
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +0 -3
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +0 -4
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +0 -2
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +0 -6
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +0 -1
- lightly_studio-0.3.1.dist-info/METADATA +0 -520
- /lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/{OpenSans- → OpenSans-Medium.DVUZMR_6.ttf} +0 -0
- {lightly_studio-0.3.1.dist-info → lightly_studio-0.3.2.dist-info}/WHEEL +0 -0
|
@@ -21,6 +21,28 @@ def create(session: Session, embedding_model: EmbeddingModelCreate) -> Embedding
|
|
|
21
21
|
return db_embedding_model
|
|
22
22
|
|
|
23
23
|
|
|
24
|
+
def get_or_create(session: Session, embedding_model: EmbeddingModelCreate) -> EmbeddingModelTable:
|
|
25
|
+
"""Retrieve an existing EmbeddingModel by hash or create a new one if it does not exist."""
|
|
26
|
+
db_model = get_by_model_hash(
|
|
27
|
+
session=session, embedding_model_hash=embedding_model.embedding_model_hash
|
|
28
|
+
)
|
|
29
|
+
if db_model is None:
|
|
30
|
+
return create(session=session, embedding_model=embedding_model)
|
|
31
|
+
|
|
32
|
+
# Validate that the existing model matches the provided data.
|
|
33
|
+
if (
|
|
34
|
+
db_model.name != embedding_model.name
|
|
35
|
+
or db_model.parameter_count_in_mb != embedding_model.parameter_count_in_mb
|
|
36
|
+
or db_model.embedding_dimension != embedding_model.embedding_dimension
|
|
37
|
+
# TODO(Michal, 09/2025): Allow same model for different datasets.
|
|
38
|
+
or db_model.dataset_id != embedding_model.dataset_id
|
|
39
|
+
):
|
|
40
|
+
raise ValueError(
|
|
41
|
+
"An embedding model with the same hash but different parameters already exists."
|
|
42
|
+
)
|
|
43
|
+
return db_model
|
|
44
|
+
|
|
45
|
+
|
|
24
46
|
def get_all_by_dataset_id(session: Session, dataset_id: UUID) -> list[EmbeddingModelTable]:
|
|
25
47
|
"""Retrieve all embedding models."""
|
|
26
48
|
embedding_models = session.exec(
|
|
@@ -7,9 +7,11 @@ from datetime import datetime, timezone
|
|
|
7
7
|
from uuid import UUID
|
|
8
8
|
|
|
9
9
|
from pydantic import BaseModel
|
|
10
|
+
from sqlalchemy.orm import joinedload, selectinload
|
|
10
11
|
from sqlmodel import Session, col, func, select
|
|
11
12
|
from sqlmodel.sql.expression import Select
|
|
12
13
|
|
|
14
|
+
from lightly_studio.api.routes.api.validators import Paginated
|
|
13
15
|
from lightly_studio.models.annotation.annotation_base import AnnotationBaseTable
|
|
14
16
|
from lightly_studio.models.annotation_label import AnnotationLabelTable
|
|
15
17
|
from lightly_studio.models.embedding_model import EmbeddingModelTable
|
|
@@ -36,6 +38,22 @@ def create_many(session: Session, samples: list[SampleCreate]) -> list[SampleTab
|
|
|
36
38
|
return db_samples
|
|
37
39
|
|
|
38
40
|
|
|
41
|
+
def filter_new_paths(session: Session, file_paths_abs: list[str]) -> tuple[list[str], list[str]]:
|
|
42
|
+
"""Return a) file_path_abs that do not already exist in the database and b) those that do."""
|
|
43
|
+
existing_file_paths_abs = set(
|
|
44
|
+
session.exec(
|
|
45
|
+
select(col(SampleTable.file_path_abs)).where(
|
|
46
|
+
col(SampleTable.file_path_abs).in_(file_paths_abs)
|
|
47
|
+
)
|
|
48
|
+
).all()
|
|
49
|
+
)
|
|
50
|
+
file_paths_abs_set = set(file_paths_abs)
|
|
51
|
+
return (
|
|
52
|
+
list(file_paths_abs_set - existing_file_paths_abs), # paths that are not in the DB
|
|
53
|
+
list(file_paths_abs_set & existing_file_paths_abs), # paths that are already in the DB
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
|
|
39
57
|
def get_by_id(session: Session, dataset_id: UUID, sample_id: UUID) -> SampleTable | None:
|
|
40
58
|
"""Retrieve a single sample by ID."""
|
|
41
59
|
return session.exec(
|
|
@@ -45,6 +63,13 @@ def get_by_id(session: Session, dataset_id: UUID, sample_id: UUID) -> SampleTabl
|
|
|
45
63
|
).one_or_none()
|
|
46
64
|
|
|
47
65
|
|
|
66
|
+
def count_by_dataset_id(session: Session, dataset_id: UUID) -> int:
|
|
67
|
+
"""Count the number of samples in a dataset."""
|
|
68
|
+
return session.exec(
|
|
69
|
+
select(func.count()).select_from(SampleTable).where(SampleTable.dataset_id == dataset_id)
|
|
70
|
+
).one()
|
|
71
|
+
|
|
72
|
+
|
|
48
73
|
def get_many_by_id(session: Session, sample_ids: list[UUID]) -> list[SampleTable]:
|
|
49
74
|
"""Retrieve multiple samples by their IDs.
|
|
50
75
|
|
|
@@ -63,19 +88,33 @@ class GetAllSamplesByDatasetIdResult(BaseModel):
|
|
|
63
88
|
|
|
64
89
|
samples: Sequence[SampleTable]
|
|
65
90
|
total_count: int
|
|
91
|
+
next_cursor: int | None = None
|
|
66
92
|
|
|
67
93
|
|
|
68
94
|
def get_all_by_dataset_id( # noqa: PLR0913
|
|
69
95
|
session: Session,
|
|
70
96
|
dataset_id: UUID,
|
|
71
|
-
|
|
72
|
-
limit: int | None = None,
|
|
97
|
+
pagination: Paginated | None = None,
|
|
73
98
|
filters: SampleFilter | None = None,
|
|
74
99
|
text_embedding: list[float] | None = None,
|
|
75
100
|
sample_ids: list[UUID] | None = None,
|
|
76
101
|
) -> GetAllSamplesByDatasetIdResult:
|
|
77
102
|
"""Retrieve samples for a specific dataset with optional filtering."""
|
|
78
|
-
samples_query =
|
|
103
|
+
samples_query = (
|
|
104
|
+
select(SampleTable)
|
|
105
|
+
.options(
|
|
106
|
+
selectinload(SampleTable.annotations).options(
|
|
107
|
+
joinedload(AnnotationBaseTable.annotation_label),
|
|
108
|
+
joinedload(AnnotationBaseTable.object_detection_details),
|
|
109
|
+
joinedload(AnnotationBaseTable.instance_segmentation_details),
|
|
110
|
+
joinedload(AnnotationBaseTable.semantic_segmentation_details),
|
|
111
|
+
),
|
|
112
|
+
selectinload(SampleTable.tags),
|
|
113
|
+
# Ignore type checker error below as it's a false positive caused by TYPE_CHECKING.
|
|
114
|
+
joinedload(SampleTable.metadata_dict), # type: ignore[arg-type]
|
|
115
|
+
)
|
|
116
|
+
.where(SampleTable.dataset_id == dataset_id)
|
|
117
|
+
)
|
|
79
118
|
total_count_query = (
|
|
80
119
|
select(func.count()).select_from(SampleTable).where(SampleTable.dataset_id == dataset_id)
|
|
81
120
|
)
|
|
@@ -120,15 +159,20 @@ def get_all_by_dataset_id( # noqa: PLR0913
|
|
|
120
159
|
col(SampleTable.created_at).asc(), col(SampleTable.sample_id).asc()
|
|
121
160
|
)
|
|
122
161
|
|
|
123
|
-
#
|
|
124
|
-
if
|
|
125
|
-
samples_query = samples_query.offset(offset)
|
|
126
|
-
|
|
127
|
-
|
|
162
|
+
# Apply pagination if provided
|
|
163
|
+
if pagination is not None:
|
|
164
|
+
samples_query = samples_query.offset(pagination.offset).limit(pagination.limit)
|
|
165
|
+
|
|
166
|
+
total_count = session.exec(total_count_query).one()
|
|
167
|
+
|
|
168
|
+
next_cursor = None
|
|
169
|
+
if pagination and pagination.offset + pagination.limit < total_count:
|
|
170
|
+
next_cursor = pagination.offset + pagination.limit
|
|
128
171
|
|
|
129
172
|
return GetAllSamplesByDatasetIdResult(
|
|
130
173
|
samples=session.exec(samples_query).all(),
|
|
131
|
-
total_count=
|
|
174
|
+
total_count=total_count,
|
|
175
|
+
next_cursor=next_cursor,
|
|
132
176
|
)
|
|
133
177
|
|
|
134
178
|
|
|
@@ -274,3 +274,26 @@ def remove_annotation_ids_from_tag_id(
|
|
|
274
274
|
session.commit()
|
|
275
275
|
session.refresh(tag)
|
|
276
276
|
return tag
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def get_or_create_sample_tag_by_name(
|
|
280
|
+
session: Session,
|
|
281
|
+
dataset_id: UUID,
|
|
282
|
+
tag_name: str,
|
|
283
|
+
) -> TagTable:
|
|
284
|
+
"""Get an existing sample tag by name or create a new one if it doesn't exist.
|
|
285
|
+
|
|
286
|
+
Args:
|
|
287
|
+
session: Database session for executing queries.
|
|
288
|
+
dataset_id: The dataset ID to search/create the tag for.
|
|
289
|
+
tag_name: Name of the tag to get or create.
|
|
290
|
+
|
|
291
|
+
Returns:
|
|
292
|
+
The existing or newly created sample tag.
|
|
293
|
+
"""
|
|
294
|
+
existing_tag = get_by_name(session=session, tag_name=tag_name, dataset_id=dataset_id)
|
|
295
|
+
if existing_tag:
|
|
296
|
+
return existing_tag
|
|
297
|
+
|
|
298
|
+
new_tag = TagCreate(name=tag_name, dataset_id=dataset_id, kind="sample")
|
|
299
|
+
return create(session=session, tag=new_tag)
|
|
@@ -1,96 +1,105 @@
|
|
|
1
|
-
"""Provides the user python interface to selection."""
|
|
1
|
+
"""Provides the user python interface to selection bound to sample ids."""
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
from collections.abc import Iterable
|
|
6
|
+
from typing import Final
|
|
5
7
|
from uuid import UUID
|
|
6
8
|
|
|
7
9
|
from sqlmodel import Session
|
|
8
10
|
|
|
9
|
-
from lightly_studio.resolvers.samples_filter import SampleFilter
|
|
10
11
|
from lightly_studio.selection.select_via_db import select_via_database
|
|
11
12
|
from lightly_studio.selection.selection_config import (
|
|
12
13
|
EmbeddingDiversityStrategy,
|
|
14
|
+
MetadataWeightingStrategy,
|
|
13
15
|
SelectionConfig,
|
|
14
16
|
SelectionStrategy,
|
|
15
17
|
)
|
|
16
18
|
|
|
17
19
|
|
|
18
20
|
class Selection:
|
|
19
|
-
"""
|
|
21
|
+
"""Selection interface for candidate sample ids."""
|
|
20
22
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
# dataset_view = ...
|
|
29
|
-
# dataset_view.select.diverse(...)
|
|
30
|
-
#
|
|
31
|
-
# See https://docs.google.com/document/d/1ZRICdFmfJmxUBy3FFoeUWsAgsCNWDHg8CK5MJiGmX74/edit?tab=t.kbfvnrepsuf#bookmark=id.8klhhwr5q4dp
|
|
32
|
-
|
|
33
|
-
def __init__(self, dataset_id: UUID, session: Session):
|
|
34
|
-
"""Creates the interface to run selection.
|
|
23
|
+
def __init__(
|
|
24
|
+
self,
|
|
25
|
+
dataset_id: UUID,
|
|
26
|
+
session: Session,
|
|
27
|
+
input_sample_ids: Iterable[UUID],
|
|
28
|
+
) -> None:
|
|
29
|
+
"""Create the selection interface.
|
|
35
30
|
|
|
36
31
|
Args:
|
|
37
|
-
dataset_id:
|
|
38
|
-
session:
|
|
32
|
+
dataset_id: Dataset in which the selection is performed.
|
|
33
|
+
session: Database session to resolve selection dependencies.
|
|
34
|
+
input_sample_ids: Candidate sample ids considered for selection.
|
|
35
|
+
The iterable is consumed immediately to capture a stable snapshot.
|
|
36
|
+
"""
|
|
37
|
+
self._dataset_id: Final[UUID] = dataset_id
|
|
38
|
+
self._session: Final[Session] = session
|
|
39
|
+
self._input_sample_ids: list[UUID] = list(input_sample_ids)
|
|
39
40
|
|
|
41
|
+
def metadata_weighting(
|
|
42
|
+
self,
|
|
43
|
+
n_samples_to_select: int,
|
|
44
|
+
selection_result_tag_name: str,
|
|
45
|
+
metadata_key: str,
|
|
46
|
+
) -> None:
|
|
47
|
+
"""Select a subset based on numeric metadata weights.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
n_samples_to_select: Number of samples to select.
|
|
51
|
+
selection_result_tag_name: Tag name for the selection result.
|
|
52
|
+
metadata_key: Metadata key used as weights (float or int values).
|
|
40
53
|
"""
|
|
41
|
-
|
|
42
|
-
self.
|
|
54
|
+
strategy = MetadataWeightingStrategy(metadata_key=metadata_key)
|
|
55
|
+
self.multi_strategies(
|
|
56
|
+
n_samples_to_select=n_samples_to_select,
|
|
57
|
+
selection_result_tag_name=selection_result_tag_name,
|
|
58
|
+
selection_strategies=[strategy],
|
|
59
|
+
)
|
|
43
60
|
|
|
44
61
|
def diverse(
|
|
45
62
|
self,
|
|
46
63
|
n_samples_to_select: int,
|
|
47
64
|
selection_result_tag_name: str,
|
|
48
65
|
embedding_model_name: str | None = None,
|
|
49
|
-
sample_filter: SampleFilter | None = None,
|
|
50
66
|
) -> None:
|
|
51
|
-
"""
|
|
67
|
+
"""Select a diverse subset using embeddings.
|
|
52
68
|
|
|
53
69
|
Args:
|
|
54
|
-
n_samples_to_select:
|
|
55
|
-
selection_result_tag_name:
|
|
56
|
-
embedding_model_name:
|
|
57
|
-
|
|
58
|
-
If None, assert that there is only one embedding model and uses it.
|
|
59
|
-
sample_filter: An optional filter to apply to the samples.
|
|
70
|
+
n_samples_to_select: Number of samples to select.
|
|
71
|
+
selection_result_tag_name: Tag name for the selection result.
|
|
72
|
+
embedding_model_name: Optional embedding model name. If None, uses the only
|
|
73
|
+
available model or raises if multiple exist.
|
|
60
74
|
"""
|
|
61
75
|
strategy = EmbeddingDiversityStrategy(embedding_model_name=embedding_model_name)
|
|
62
|
-
|
|
63
|
-
dataset_id=self.dataset_id,
|
|
76
|
+
self.multi_strategies(
|
|
64
77
|
n_samples_to_select=n_samples_to_select,
|
|
65
78
|
selection_result_tag_name=selection_result_tag_name,
|
|
66
|
-
|
|
67
|
-
strategies=[strategy],
|
|
79
|
+
selection_strategies=[strategy],
|
|
68
80
|
)
|
|
69
|
-
select_via_database(session=self.session, config=selection_config)
|
|
70
81
|
|
|
71
82
|
def multi_strategies(
|
|
72
83
|
self,
|
|
73
84
|
n_samples_to_select: int,
|
|
74
85
|
selection_result_tag_name: str,
|
|
75
86
|
selection_strategies: list[SelectionStrategy],
|
|
76
|
-
sample_filter: SampleFilter | None = None,
|
|
77
87
|
) -> None:
|
|
78
|
-
"""Select a subset
|
|
88
|
+
"""Select a subset based on multiple strategies.
|
|
79
89
|
|
|
80
90
|
Args:
|
|
81
|
-
n_samples_to_select:
|
|
82
|
-
selection_result_tag_name:
|
|
83
|
-
selection_strategies:
|
|
84
|
-
Selection strategies to use for the selection. They can be created after
|
|
85
|
-
importing them from `lightly_studio.selection.selection_config`.
|
|
86
|
-
sample_filter: An optional filter to apply to the samples.
|
|
87
|
-
|
|
91
|
+
n_samples_to_select: Number of samples to select.
|
|
92
|
+
selection_result_tag_name: Tag name for the selection result.
|
|
93
|
+
selection_strategies: Strategies to compose for selection.
|
|
88
94
|
"""
|
|
89
95
|
config = SelectionConfig(
|
|
90
|
-
dataset_id=self.
|
|
96
|
+
dataset_id=self._dataset_id,
|
|
91
97
|
n_samples_to_select=n_samples_to_select,
|
|
92
98
|
selection_result_tag_name=selection_result_tag_name,
|
|
93
|
-
sample_filter=sample_filter,
|
|
94
99
|
strategies=selection_strategies,
|
|
95
100
|
)
|
|
96
|
-
select_via_database(
|
|
101
|
+
select_via_database(
|
|
102
|
+
session=self._session,
|
|
103
|
+
config=config,
|
|
104
|
+
input_sample_ids=self._input_sample_ids,
|
|
105
|
+
)
|
|
@@ -3,29 +3,33 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import datetime
|
|
6
|
+
from uuid import UUID
|
|
6
7
|
|
|
7
8
|
from sqlmodel import Session
|
|
8
9
|
|
|
9
10
|
from lightly_studio.models.tag import TagCreate
|
|
10
11
|
from lightly_studio.resolvers import (
|
|
11
12
|
embedding_model_resolver,
|
|
13
|
+
metadata_resolver,
|
|
12
14
|
sample_embedding_resolver,
|
|
13
|
-
sample_resolver,
|
|
14
15
|
tag_resolver,
|
|
15
16
|
)
|
|
16
17
|
from lightly_studio.selection.mundig import Mundig
|
|
17
18
|
from lightly_studio.selection.selection_config import (
|
|
18
19
|
EmbeddingDiversityStrategy,
|
|
20
|
+
MetadataWeightingStrategy,
|
|
19
21
|
SelectionConfig,
|
|
20
22
|
)
|
|
21
23
|
|
|
22
24
|
|
|
23
|
-
def select_via_database(
|
|
24
|
-
|
|
25
|
+
def select_via_database(
|
|
26
|
+
session: Session, config: SelectionConfig, input_sample_ids: list[UUID]
|
|
27
|
+
) -> None:
|
|
28
|
+
"""Run selection using the provided candidate sample ids.
|
|
25
29
|
|
|
26
|
-
First resolves the selection config to
|
|
30
|
+
First resolves the selection config to concrete database values.
|
|
27
31
|
Then calls Mundig to run the selection with pure values.
|
|
28
|
-
|
|
32
|
+
Finally creates a tag for the selected set.
|
|
29
33
|
"""
|
|
30
34
|
# Check if the tag name is already used
|
|
31
35
|
existing_tag = tag_resolver.get_by_name(
|
|
@@ -40,18 +44,7 @@ def select_via_database(session: Session, config: SelectionConfig) -> None:
|
|
|
40
44
|
)
|
|
41
45
|
raise ValueError(msg)
|
|
42
46
|
|
|
43
|
-
|
|
44
|
-
# the latter is implemented.
|
|
45
|
-
# See https://linear.app/lightly/issue/LIG-7292/story-python-ui-mvp1-without-datasetquery-and-sample
|
|
46
|
-
samples = sample_resolver.get_all_by_dataset_id(
|
|
47
|
-
session,
|
|
48
|
-
limit=None,
|
|
49
|
-
dataset_id=config.dataset_id,
|
|
50
|
-
filters=config.sample_filter,
|
|
51
|
-
).samples
|
|
52
|
-
sample_ids = [s.sample_id for s in samples]
|
|
53
|
-
|
|
54
|
-
n_samples_to_select = min(config.n_samples_to_select, len(sample_ids))
|
|
47
|
+
n_samples_to_select = min(config.n_samples_to_select, len(input_sample_ids))
|
|
55
48
|
if n_samples_to_select == 0:
|
|
56
49
|
print("No samples available for selection.")
|
|
57
50
|
return
|
|
@@ -66,16 +59,27 @@ def select_via_database(session: Session, config: SelectionConfig) -> None:
|
|
|
66
59
|
).embedding_model_id
|
|
67
60
|
embedding_tables = sample_embedding_resolver.get_by_sample_ids(
|
|
68
61
|
session=session,
|
|
69
|
-
sample_ids=
|
|
62
|
+
sample_ids=input_sample_ids,
|
|
70
63
|
embedding_model_id=embedding_model_id,
|
|
71
64
|
)
|
|
72
65
|
embeddings = [e.embedding for e in embedding_tables]
|
|
73
66
|
mundig.add_diversity(embeddings=embeddings, strength=strat.strength)
|
|
67
|
+
elif isinstance(strat, MetadataWeightingStrategy):
|
|
68
|
+
key = strat.metadata_key
|
|
69
|
+
weights = []
|
|
70
|
+
for sample_id in input_sample_ids:
|
|
71
|
+
weight = metadata_resolver.get_value_for_sample(session, sample_id, key)
|
|
72
|
+
if not isinstance(weight, (float, int)):
|
|
73
|
+
raise ValueError(
|
|
74
|
+
f"Metadata {key} is not a number, only numbers can be used as weights"
|
|
75
|
+
)
|
|
76
|
+
weights.append(float(weight))
|
|
77
|
+
mundig.add_weighting(weights, strength=strat.strength)
|
|
74
78
|
else:
|
|
75
79
|
raise ValueError(f"Selection strategy of type {type(strat)} is unknown.")
|
|
76
80
|
|
|
77
81
|
selected_indices = mundig.run(n_samples=n_samples_to_select)
|
|
78
|
-
selected_sample_ids = [
|
|
82
|
+
selected_sample_ids = [input_sample_ids[i] for i in selected_indices]
|
|
79
83
|
|
|
80
84
|
datetime_str = datetime.datetime.now(tz=datetime.timezone.utc).isoformat()
|
|
81
85
|
tag_description = f"Selected at {datetime_str} UTC"
|
|
@@ -6,14 +6,11 @@ from uuid import UUID
|
|
|
6
6
|
|
|
7
7
|
from pydantic import BaseModel
|
|
8
8
|
|
|
9
|
-
from lightly_studio.resolvers.samples_filter import SampleFilter
|
|
10
|
-
|
|
11
9
|
|
|
12
10
|
class SelectionConfig(BaseModel):
|
|
13
11
|
"""Configuration for the selection process."""
|
|
14
12
|
|
|
15
13
|
dataset_id: UUID
|
|
16
|
-
sample_filter: SampleFilter | None = None
|
|
17
14
|
n_samples_to_select: int
|
|
18
15
|
selection_result_tag_name: str
|
|
19
16
|
strategies: list[SelectionStrategy]
|
|
@@ -29,3 +26,9 @@ class EmbeddingDiversityStrategy(SelectionStrategy):
|
|
|
29
26
|
"""Selection strategy based on embedding diversity."""
|
|
30
27
|
|
|
31
28
|
embedding_model_name: str | None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class MetadataWeightingStrategy(SelectionStrategy):
|
|
32
|
+
"""Selection strategy based on metadata weighting."""
|
|
33
|
+
|
|
34
|
+
metadata_key: str
|
|
@@ -6,6 +6,9 @@ from lightly_studio.services.annotations_service.get_annotation_by_id import (
|
|
|
6
6
|
from lightly_studio.services.annotations_service.update_annotation import (
|
|
7
7
|
update_annotation,
|
|
8
8
|
)
|
|
9
|
+
from lightly_studio.services.annotations_service.update_annotation_bounding_box import (
|
|
10
|
+
update_annotation_bounding_box,
|
|
11
|
+
)
|
|
9
12
|
from lightly_studio.services.annotations_service.update_annotation_label import (
|
|
10
13
|
update_annotation_label,
|
|
11
14
|
)
|
|
@@ -16,6 +19,7 @@ from lightly_studio.services.annotations_service.update_annotations import (
|
|
|
16
19
|
__all__ = [
|
|
17
20
|
"get_annotation_by_id",
|
|
18
21
|
"update_annotation",
|
|
22
|
+
"update_annotation_bounding_box",
|
|
19
23
|
"update_annotation_label",
|
|
20
24
|
"update_annotations",
|
|
21
25
|
]
|
|
@@ -10,6 +10,7 @@ from sqlmodel import Session
|
|
|
10
10
|
from lightly_studio.models.annotation.annotation_base import (
|
|
11
11
|
AnnotationBaseTable,
|
|
12
12
|
)
|
|
13
|
+
from lightly_studio.resolvers.annotation_resolver.update_bounding_box import BoundingBoxCoordinates
|
|
13
14
|
from lightly_studio.services import annotations_service
|
|
14
15
|
|
|
15
16
|
|
|
@@ -18,11 +19,8 @@ class AnnotationUpdate(BaseModel):
|
|
|
18
19
|
|
|
19
20
|
annotation_id: UUID
|
|
20
21
|
dataset_id: UUID
|
|
21
|
-
label_name: str | None
|
|
22
|
-
|
|
23
|
-
y: int | None = None
|
|
24
|
-
width: int | None = None
|
|
25
|
-
height: int | None = None
|
|
22
|
+
label_name: str | None = None
|
|
23
|
+
bounding_box: BoundingBoxCoordinates | None = None
|
|
26
24
|
|
|
27
25
|
|
|
28
26
|
def update_annotation(session: Session, annotation_update: AnnotationUpdate) -> AnnotationBaseTable:
|
|
@@ -36,30 +34,21 @@ def update_annotation(session: Session, annotation_update: AnnotationUpdate) ->
|
|
|
36
34
|
The updated annotation.
|
|
37
35
|
|
|
38
36
|
"""
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
# "All bounding box coordinates (x, y, width, height) "
|
|
58
|
-
# "must be provided for updating this annotation type"
|
|
59
|
-
# )
|
|
60
|
-
|
|
61
|
-
return annotations_service.update_annotation_label(
|
|
62
|
-
session,
|
|
63
|
-
annotation_update.annotation_id,
|
|
64
|
-
annotation_update.label_name,
|
|
65
|
-
)
|
|
37
|
+
result = None
|
|
38
|
+
if annotation_update.label_name is not None:
|
|
39
|
+
result = annotations_service.update_annotation_label(
|
|
40
|
+
session,
|
|
41
|
+
annotation_update.annotation_id,
|
|
42
|
+
annotation_update.label_name,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
if annotation_update.bounding_box is not None:
|
|
46
|
+
result = annotations_service.update_annotation_bounding_box(
|
|
47
|
+
session,
|
|
48
|
+
annotation_update.annotation_id,
|
|
49
|
+
bounding_box=annotation_update.bounding_box,
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
if result is None:
|
|
53
|
+
raise ValueError("No updates provided for the annotation.")
|
|
54
|
+
return result
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Update the bounding box of an annotation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from uuid import UUID
|
|
6
|
+
|
|
7
|
+
from sqlmodel import Session
|
|
8
|
+
|
|
9
|
+
from lightly_studio.models.annotation.annotation_base import (
|
|
10
|
+
AnnotationBaseTable,
|
|
11
|
+
)
|
|
12
|
+
from lightly_studio.resolvers import (
|
|
13
|
+
annotation_resolver,
|
|
14
|
+
)
|
|
15
|
+
from lightly_studio.resolvers.annotation_resolver.update_bounding_box import BoundingBoxCoordinates
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def update_annotation_bounding_box(
|
|
19
|
+
session: Session, annotation_id: UUID, bounding_box: BoundingBoxCoordinates
|
|
20
|
+
) -> AnnotationBaseTable:
|
|
21
|
+
"""Update the bounding box of an annotation.
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
session: Database session for executing the operation.
|
|
25
|
+
annotation_id: UUID of the annotation to update.
|
|
26
|
+
bounding_box: New bounding box coordinates to assign to the annotation.
|
|
27
|
+
|
|
28
|
+
Returns:
|
|
29
|
+
The updated annotation with the new bounding box assigned.
|
|
30
|
+
|
|
31
|
+
"""
|
|
32
|
+
return annotation_resolver.update_bounding_box(
|
|
33
|
+
session,
|
|
34
|
+
annotation_id,
|
|
35
|
+
bounding_box,
|
|
36
|
+
)
|