lightly-studio 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lightly-studio might be problematic. Click here for more details.
- lightly_studio/__init__.py +4 -4
- lightly_studio/api/app.py +1 -1
- lightly_studio/api/routes/api/annotation.py +6 -16
- lightly_studio/api/routes/api/annotation_label.py +2 -5
- lightly_studio/api/routes/api/annotation_task.py +4 -5
- lightly_studio/api/routes/api/classifier.py +2 -5
- lightly_studio/api/routes/api/dataset.py +2 -3
- lightly_studio/api/routes/api/dataset_tag.py +2 -3
- lightly_studio/api/routes/api/metadata.py +2 -4
- lightly_studio/api/routes/api/metrics.py +2 -6
- lightly_studio/api/routes/api/sample.py +5 -13
- lightly_studio/api/routes/api/settings.py +2 -6
- lightly_studio/api/routes/images.py +6 -6
- lightly_studio/core/add_samples.py +383 -0
- lightly_studio/core/dataset.py +250 -362
- lightly_studio/core/dataset_query/__init__.py +0 -0
- lightly_studio/core/dataset_query/boolean_expression.py +67 -0
- lightly_studio/core/dataset_query/dataset_query.py +211 -0
- lightly_studio/core/dataset_query/field.py +113 -0
- lightly_studio/core/dataset_query/field_expression.py +79 -0
- lightly_studio/core/dataset_query/match_expression.py +23 -0
- lightly_studio/core/dataset_query/order_by.py +79 -0
- lightly_studio/core/dataset_query/sample_field.py +28 -0
- lightly_studio/core/dataset_query/tags_expression.py +46 -0
- lightly_studio/core/sample.py +159 -32
- lightly_studio/core/start_gui.py +35 -0
- lightly_studio/dataset/edge_embedding_generator.py +13 -8
- lightly_studio/dataset/embedding_generator.py +2 -3
- lightly_studio/dataset/embedding_manager.py +74 -6
- lightly_studio/dataset/fsspec_lister.py +275 -0
- lightly_studio/dataset/loader.py +49 -30
- lightly_studio/dataset/mobileclip_embedding_generator.py +6 -4
- lightly_studio/db_manager.py +145 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.BBm0IWdq.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.BNTuXSAe.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/2O287xak.js +3 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{O-EABkf9.js → 7YNGEs1C.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BBoGk9hq.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BRnH9v23.js +92 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bg1Y5eUZ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{DOlTMNyt.js → BqBqV92V.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C0JiMuYn.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{DjfY96ND.js → C98Hk3r5.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{r64xT6ao.js → CG0dMCJi.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{C8I8rFJQ.js → Ccq4ZD0B.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cpy-nab_.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{Bu7uvVrG.js → Crk-jcvV.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cs31G8Qn.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CsKrY2zA.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{x9G_hzyY.js → Cur71c3O.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CzgC3GFB.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D8GZDMNN.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DFRh-Spp.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{BylOuP6i.js → DRZO-E-T.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{l7KrR96u.js → DcGCxgpH.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{Bsi3UGy5.js → Df3aMO5B.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{hQVEETDE.js → DkR_EZ_B.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DqUGznj_.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KpAtIldw.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/M1Q1F7bw.js +4 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{CDnpyLsT.js → OH7-C_mc.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{D6su9Aln.js → gLNdjSzu.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/i0ZZ4z06.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.BI-EA5gL.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.CcsRl3cZ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.BbO4Zc3r.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{1.B4rNYwVp.js → 1._I9GR805.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.J2RBFrSr.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.Cmqj25a-.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.C45iKJHA.js +6 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{3.CWHpKonm.js → 3.w9g4AcAx.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{4.OUWOLQeV.js → 4.BBI8KwnD.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.huHuxdiF.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.CrbkRPam.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.FomEdhD6.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cb_ADSLk.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{9.CPu3CiBc.js → 9.CajIG5ce.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -1
- lightly_studio/dist_lightly_studio_view_app/index.html +14 -14
- lightly_studio/examples/example.py +13 -12
- lightly_studio/examples/example_coco.py +13 -0
- lightly_studio/examples/example_metadata.py +83 -98
- lightly_studio/examples/example_selection.py +7 -19
- lightly_studio/examples/example_split_work.py +12 -36
- lightly_studio/examples/{example_v2.py → example_yolo.py} +3 -4
- lightly_studio/models/annotation/annotation_base.py +7 -8
- lightly_studio/models/annotation/instance_segmentation.py +8 -8
- lightly_studio/models/annotation/object_detection.py +4 -4
- lightly_studio/models/dataset.py +6 -2
- lightly_studio/models/sample.py +10 -3
- lightly_studio/resolvers/dataset_resolver.py +10 -0
- lightly_studio/resolvers/embedding_model_resolver.py +22 -0
- lightly_studio/resolvers/sample_resolver.py +53 -9
- lightly_studio/resolvers/tag_resolver.py +23 -0
- lightly_studio/selection/select.py +55 -46
- lightly_studio/selection/select_via_db.py +23 -19
- lightly_studio/selection/selection_config.py +6 -3
- lightly_studio/services/annotations_service/__init__.py +4 -0
- lightly_studio/services/annotations_service/update_annotation.py +21 -32
- lightly_studio/services/annotations_service/update_annotation_bounding_box.py +36 -0
- lightly_studio-0.3.2.dist-info/METADATA +689 -0
- {lightly_studio-0.3.1.dist-info → lightly_studio-0.3.2.dist-info}/RECORD +104 -91
- lightly_studio/api/db.py +0 -133
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +0 -93
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +0 -3
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +0 -4
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +0 -2
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +0 -6
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +0 -1
- lightly_studio-0.3.1.dist-info/METADATA +0 -520
- /lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/{OpenSans- → OpenSans-Medium.DVUZMR_6.ttf} +0 -0
- {lightly_studio-0.3.1.dist-info → lightly_studio-0.3.2.dist-info}/WHEEL +0 -0
lightly_studio/core/dataset.py
CHANGED
|
@@ -2,44 +2,39 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from dataclasses import dataclass
|
|
6
5
|
from pathlib import Path
|
|
7
|
-
from typing import Iterable
|
|
6
|
+
from typing import Iterable, Iterator
|
|
8
7
|
from uuid import UUID
|
|
9
8
|
|
|
10
|
-
import PIL
|
|
11
9
|
from labelformat.formats import (
|
|
12
10
|
COCOInstanceSegmentationInput,
|
|
13
11
|
COCOObjectDetectionInput,
|
|
14
12
|
YOLOv8ObjectDetectionInput,
|
|
15
13
|
)
|
|
16
|
-
from labelformat.model.binary_mask_segmentation import BinaryMaskSegmentation
|
|
17
|
-
from labelformat.model.bounding_box import BoundingBoxFormat
|
|
18
|
-
from labelformat.model.image import Image
|
|
19
14
|
from labelformat.model.instance_segmentation import (
|
|
20
|
-
ImageInstanceSegmentation,
|
|
21
15
|
InstanceSegmentationInput,
|
|
22
16
|
)
|
|
23
|
-
from labelformat.model.multipolygon import MultiPolygon
|
|
24
17
|
from labelformat.model.object_detection import (
|
|
25
|
-
ImageObjectDetection,
|
|
26
18
|
ObjectDetectionInput,
|
|
27
19
|
)
|
|
28
|
-
from sqlmodel import Session
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
from lightly_studio.api
|
|
32
|
-
from lightly_studio.
|
|
33
|
-
from lightly_studio.
|
|
20
|
+
from sqlmodel import Session, select
|
|
21
|
+
|
|
22
|
+
from lightly_studio import db_manager
|
|
23
|
+
from lightly_studio.api import features
|
|
24
|
+
from lightly_studio.core import add_samples
|
|
25
|
+
from lightly_studio.core.dataset_query.dataset_query import DatasetQuery
|
|
26
|
+
from lightly_studio.core.dataset_query.match_expression import MatchExpression
|
|
27
|
+
from lightly_studio.core.dataset_query.order_by import OrderByExpression
|
|
28
|
+
from lightly_studio.core.sample import Sample
|
|
29
|
+
from lightly_studio.dataset import fsspec_lister
|
|
30
|
+
from lightly_studio.dataset.embedding_manager import EmbeddingManagerProvider
|
|
34
31
|
from lightly_studio.models.annotation_task import (
|
|
35
32
|
AnnotationTaskTable,
|
|
36
33
|
AnnotationType,
|
|
37
34
|
)
|
|
38
35
|
from lightly_studio.models.dataset import DatasetCreate, DatasetTable
|
|
39
|
-
from lightly_studio.models.sample import
|
|
36
|
+
from lightly_studio.models.sample import SampleTable
|
|
40
37
|
from lightly_studio.resolvers import (
|
|
41
|
-
annotation_label_resolver,
|
|
42
|
-
annotation_resolver,
|
|
43
38
|
annotation_task_resolver,
|
|
44
39
|
dataset_resolver,
|
|
45
40
|
sample_resolver,
|
|
@@ -47,97 +42,201 @@ from lightly_studio.resolvers import (
|
|
|
47
42
|
from lightly_studio.type_definitions import PathLike
|
|
48
43
|
|
|
49
44
|
# Constants
|
|
50
|
-
|
|
51
|
-
SAMPLE_BATCH_SIZE = 32 # Number of samples to process in a single batch
|
|
52
|
-
|
|
45
|
+
DEFAULT_DATASET_NAME = "default_dataset"
|
|
53
46
|
|
|
54
|
-
|
|
55
|
-
class AnnotationProcessingContext:
|
|
56
|
-
"""Context for processing annotations for a single sample."""
|
|
57
|
-
|
|
58
|
-
dataset_id: UUID
|
|
59
|
-
sample_id: UUID
|
|
60
|
-
label_map: dict[int, UUID]
|
|
61
|
-
annotation_task_id: UUID
|
|
47
|
+
_SliceType = slice # to avoid shadowing built-in slice in type annotations
|
|
62
48
|
|
|
63
49
|
|
|
64
50
|
class Dataset:
|
|
65
51
|
"""A LightlyStudio Dataset.
|
|
66
52
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
Args:
|
|
70
|
-
name: The name of the dataset. If None, a default name will be assigned.
|
|
53
|
+
Keeps a reference to the underlying DatasetTable.
|
|
71
54
|
"""
|
|
72
55
|
|
|
73
|
-
def __init__(self,
|
|
56
|
+
def __init__(self, dataset: DatasetTable) -> None:
|
|
74
57
|
"""Initialize a LightlyStudio Dataset."""
|
|
58
|
+
self._inner = dataset
|
|
59
|
+
# TODO(Michal, 09/2025): Do not store the session. Instead, use the
|
|
60
|
+
# dataset object session.
|
|
61
|
+
self.session = db_manager.persistent_session()
|
|
62
|
+
|
|
63
|
+
@staticmethod
|
|
64
|
+
def create(name: str | None = None) -> Dataset:
|
|
65
|
+
"""Create a new dataset."""
|
|
66
|
+
if name is None:
|
|
67
|
+
name = DEFAULT_DATASET_NAME
|
|
68
|
+
|
|
69
|
+
dataset = dataset_resolver.create(
|
|
70
|
+
session=db_manager.persistent_session(),
|
|
71
|
+
dataset=DatasetCreate(name=name, directory=""),
|
|
72
|
+
)
|
|
73
|
+
return Dataset(dataset=dataset)
|
|
74
|
+
|
|
75
|
+
@staticmethod
|
|
76
|
+
def load(name: str | None = None) -> Dataset:
|
|
77
|
+
"""Load an existing dataset."""
|
|
75
78
|
if name is None:
|
|
76
79
|
name = "default_dataset"
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
80
|
+
|
|
81
|
+
dataset = dataset_resolver.get_by_name(session=db_manager.persistent_session(), name=name)
|
|
82
|
+
if dataset is None:
|
|
83
|
+
raise ValueError(f"Dataset with name '{name}' not found.")
|
|
84
|
+
|
|
85
|
+
return Dataset(dataset=dataset)
|
|
86
|
+
|
|
87
|
+
@staticmethod
|
|
88
|
+
def load_or_create(name: str | None = None) -> Dataset:
|
|
89
|
+
"""Create a new dataset or load an existing one."""
|
|
90
|
+
if name is None:
|
|
91
|
+
name = "default_dataset"
|
|
92
|
+
|
|
93
|
+
dataset = dataset_resolver.get_by_name(session=db_manager.persistent_session(), name=name)
|
|
94
|
+
if dataset is None:
|
|
95
|
+
return Dataset.create(name=name)
|
|
96
|
+
|
|
97
|
+
return Dataset(dataset=dataset)
|
|
98
|
+
|
|
99
|
+
def __iter__(self) -> Iterator[Sample]:
|
|
100
|
+
"""Iterate over samples in the dataset."""
|
|
101
|
+
for sample in self.session.exec(
|
|
102
|
+
select(SampleTable).where(SampleTable.dataset_id == self.dataset_id)
|
|
103
|
+
):
|
|
104
|
+
yield Sample(inner=sample)
|
|
105
|
+
|
|
106
|
+
def get_sample(self, sample_id: UUID) -> Sample:
|
|
107
|
+
"""Get a single sample from the dataset by its ID.
|
|
108
|
+
|
|
109
|
+
Args:
|
|
110
|
+
sample_id: The UUID of the sample to retrieve.
|
|
111
|
+
|
|
112
|
+
Returns:
|
|
113
|
+
A single SampleTable object.
|
|
114
|
+
|
|
115
|
+
Raises:
|
|
116
|
+
IndexError: If no sample is found with the given sample_id.
|
|
117
|
+
"""
|
|
118
|
+
sample = sample_resolver.get_by_id(
|
|
119
|
+
self.session, dataset_id=self.dataset_id, sample_id=sample_id
|
|
86
120
|
)
|
|
87
121
|
|
|
122
|
+
if sample is None:
|
|
123
|
+
raise IndexError(f"No sample found for sample_id: {sample_id}")
|
|
124
|
+
return Sample(inner=sample)
|
|
125
|
+
|
|
88
126
|
@property
|
|
89
127
|
def dataset_id(self) -> UUID:
|
|
90
128
|
"""Get the dataset ID."""
|
|
91
|
-
return self.
|
|
129
|
+
return self._inner.dataset_id
|
|
130
|
+
|
|
131
|
+
@property
|
|
132
|
+
def name(self) -> str:
|
|
133
|
+
"""Get the dataset name."""
|
|
134
|
+
return self._inner.name
|
|
135
|
+
|
|
136
|
+
def query(self) -> DatasetQuery:
|
|
137
|
+
"""Create a DatasetQuery for this dataset.
|
|
138
|
+
|
|
139
|
+
Returns:
|
|
140
|
+
A DatasetQuery instance for querying samples in this dataset.
|
|
141
|
+
"""
|
|
142
|
+
return DatasetQuery(dataset=self._inner, session=self.session)
|
|
143
|
+
|
|
144
|
+
def match(self, match_expression: MatchExpression) -> DatasetQuery:
|
|
145
|
+
"""Create a query on the dataset and store a field condition for filtering.
|
|
146
|
+
|
|
147
|
+
Args:
|
|
148
|
+
match_expression: Defines the filter.
|
|
149
|
+
|
|
150
|
+
Returns:
|
|
151
|
+
DatasetQuery for method chaining.
|
|
152
|
+
"""
|
|
153
|
+
return self.query().match(match_expression)
|
|
154
|
+
|
|
155
|
+
def order_by(self, *order_by: OrderByExpression) -> DatasetQuery:
|
|
156
|
+
"""Create a query on the dataset and store ordering expressions.
|
|
157
|
+
|
|
158
|
+
Args:
|
|
159
|
+
order_by: One or more ordering expressions. They are applied in order.
|
|
160
|
+
E.g. first ordering by sample width and then by sample file_name will
|
|
161
|
+
only order the samples with the same sample width by file_name.
|
|
162
|
+
|
|
163
|
+
Returns:
|
|
164
|
+
DatasetQuery for method chaining.
|
|
165
|
+
"""
|
|
166
|
+
return self.query().order_by(*order_by)
|
|
167
|
+
|
|
168
|
+
def slice(self, offset: int = 0, limit: int | None = None) -> DatasetQuery:
|
|
169
|
+
"""Create a query on the dataset and apply offset and limit to results.
|
|
170
|
+
|
|
171
|
+
Args:
|
|
172
|
+
offset: Number of items to skip from beginning (default: 0).
|
|
173
|
+
limit: Maximum number of items to return (None = no limit).
|
|
174
|
+
|
|
175
|
+
Returns:
|
|
176
|
+
DatasetQuery for method chaining.
|
|
177
|
+
"""
|
|
178
|
+
return self.query().slice(offset, limit)
|
|
179
|
+
|
|
180
|
+
def __getitem__(self, key: _SliceType) -> DatasetQuery:
|
|
181
|
+
"""Create a query on the dataset and enable bracket notation for slicing.
|
|
182
|
+
|
|
183
|
+
Args:
|
|
184
|
+
key: A slice object (e.g., [10:20], [:50], [100:]).
|
|
185
|
+
|
|
186
|
+
Returns:
|
|
187
|
+
DatasetQuery with slice applied.
|
|
188
|
+
|
|
189
|
+
Raises:
|
|
190
|
+
TypeError: If key is not a slice object.
|
|
191
|
+
ValueError: If slice contains unsupported features or conflicts with existing slice.
|
|
192
|
+
"""
|
|
193
|
+
return self.query()[key]
|
|
92
194
|
|
|
93
195
|
def add_samples_from_path(
|
|
94
196
|
self,
|
|
95
197
|
path: PathLike,
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
".png",
|
|
99
|
-
".jpg",
|
|
100
|
-
".jpeg",
|
|
101
|
-
".gif",
|
|
102
|
-
".webp",
|
|
103
|
-
".bmp",
|
|
104
|
-
".tiff",
|
|
105
|
-
},
|
|
198
|
+
allowed_extensions: Iterable[str] | None = None,
|
|
199
|
+
embed: bool = True,
|
|
106
200
|
) -> None:
|
|
107
201
|
"""Adding samples from the specified path to the dataset.
|
|
108
202
|
|
|
109
203
|
Args:
|
|
110
204
|
path: Path to the folder containing the images to add.
|
|
111
|
-
recursive: If True, search for images recursively in subfolders.
|
|
112
205
|
allowed_extensions: An iterable container of allowed image file
|
|
113
206
|
extensions.
|
|
207
|
+
embed: If True, generate embeddings for the newly added samples.
|
|
114
208
|
"""
|
|
115
|
-
path = Path(path).absolute() if isinstance(path, str) else path.absolute()
|
|
116
|
-
if not path.exists() or not path.is_dir():
|
|
117
|
-
raise ValueError(f"Provided path is not a valid directory: {path}")
|
|
118
|
-
|
|
119
209
|
# Collect image file paths.
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
210
|
+
if allowed_extensions:
|
|
211
|
+
allowed_extensions_set = {ext.lower() for ext in allowed_extensions}
|
|
212
|
+
else:
|
|
213
|
+
allowed_extensions_set = None
|
|
214
|
+
image_paths = list(
|
|
215
|
+
fsspec_lister.iter_files_from_path(
|
|
216
|
+
path=str(path), allowed_extensions=allowed_extensions_set
|
|
217
|
+
)
|
|
218
|
+
)
|
|
126
219
|
print(f"Found {len(image_paths)} images in {path}.")
|
|
127
220
|
|
|
128
221
|
# Process images.
|
|
129
|
-
|
|
222
|
+
created_sample_ids = add_samples.load_into_dataset_from_paths(
|
|
130
223
|
session=self.session,
|
|
131
224
|
dataset_id=self.dataset_id,
|
|
132
225
|
image_paths=image_paths,
|
|
133
226
|
)
|
|
134
227
|
|
|
228
|
+
if embed:
|
|
229
|
+
_generate_embeddings(
|
|
230
|
+
session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
|
|
231
|
+
)
|
|
232
|
+
|
|
135
233
|
def add_samples_from_labelformat(
|
|
136
234
|
self,
|
|
137
235
|
input_labels: ObjectDetectionInput | InstanceSegmentationInput,
|
|
138
236
|
images_path: PathLike,
|
|
139
237
|
is_prediction: bool = True,
|
|
140
238
|
task_name: str | None = None,
|
|
239
|
+
embed: bool = True,
|
|
141
240
|
) -> None:
|
|
142
241
|
"""Load a dataset from a labelformat object and store in database.
|
|
143
242
|
|
|
@@ -147,9 +246,7 @@ class Dataset:
|
|
|
147
246
|
is_prediction: Whether the task is for prediction or labels.
|
|
148
247
|
task_name: Optional name for the annotation task. If None, a
|
|
149
248
|
default name is generated.
|
|
150
|
-
|
|
151
|
-
Returns:
|
|
152
|
-
DatasetTable: The created dataset table entry.
|
|
249
|
+
embed: If True, generate embeddings for the newly added samples.
|
|
153
250
|
"""
|
|
154
251
|
if isinstance(images_path, str):
|
|
155
252
|
images_path = Path(images_path)
|
|
@@ -174,7 +271,7 @@ class Dataset:
|
|
|
174
271
|
),
|
|
175
272
|
)
|
|
176
273
|
|
|
177
|
-
|
|
274
|
+
created_sample_ids = add_samples.load_into_dataset_from_labelformat(
|
|
178
275
|
session=self.session,
|
|
179
276
|
dataset_id=self.dataset_id,
|
|
180
277
|
input_labels=input_labels,
|
|
@@ -182,25 +279,33 @@ class Dataset:
|
|
|
182
279
|
annotation_task_id=new_annotation_task.annotation_task_id,
|
|
183
280
|
)
|
|
184
281
|
|
|
185
|
-
|
|
282
|
+
if embed:
|
|
283
|
+
_generate_embeddings(
|
|
284
|
+
session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
|
|
285
|
+
)
|
|
286
|
+
|
|
287
|
+
def add_samples_from_yolo(
|
|
186
288
|
self,
|
|
187
|
-
|
|
289
|
+
data_yaml: PathLike,
|
|
188
290
|
input_split: str = "train",
|
|
189
291
|
task_name: str | None = None,
|
|
190
|
-
|
|
292
|
+
embed: bool = True,
|
|
293
|
+
) -> None:
|
|
191
294
|
"""Load a dataset in YOLO format and store in DB.
|
|
192
295
|
|
|
193
296
|
Args:
|
|
194
|
-
|
|
297
|
+
data_yaml: Path to the YOLO data.yaml file.
|
|
195
298
|
input_split: The split to load (e.g., 'train', 'val').
|
|
196
299
|
task_name: Optional name for the annotation task. If None, a
|
|
197
300
|
default name is generated.
|
|
198
|
-
|
|
199
|
-
Returns:
|
|
200
|
-
DatasetTable: The created dataset table entry.
|
|
301
|
+
embed: If True, generate embeddings for the newly added samples.
|
|
201
302
|
"""
|
|
202
|
-
data_yaml
|
|
203
|
-
|
|
303
|
+
if isinstance(data_yaml, str):
|
|
304
|
+
data_yaml = Path(data_yaml)
|
|
305
|
+
data_yaml = data_yaml.absolute()
|
|
306
|
+
|
|
307
|
+
if not data_yaml.is_file() or data_yaml.suffix != ".yaml":
|
|
308
|
+
raise FileNotFoundError(f"YOLO data yaml file not found: '{data_yaml}'")
|
|
204
309
|
|
|
205
310
|
if task_name is None:
|
|
206
311
|
task_name = f"Loaded from YOLO: {data_yaml.name} ({input_split} split)"
|
|
@@ -210,314 +315,97 @@ class Dataset:
|
|
|
210
315
|
input_file=data_yaml,
|
|
211
316
|
input_split=input_split,
|
|
212
317
|
)
|
|
213
|
-
|
|
318
|
+
images_path = label_input._images_dir() # noqa: SLF001
|
|
214
319
|
|
|
215
|
-
|
|
320
|
+
self.add_samples_from_labelformat(
|
|
216
321
|
input_labels=label_input,
|
|
217
|
-
|
|
218
|
-
img_dir=str(img_dir),
|
|
322
|
+
images_path=images_path,
|
|
219
323
|
is_prediction=False,
|
|
220
324
|
task_name=task_name,
|
|
325
|
+
embed=embed,
|
|
221
326
|
)
|
|
222
327
|
|
|
223
|
-
def
|
|
328
|
+
def add_samples_from_coco(
|
|
224
329
|
self,
|
|
225
|
-
|
|
226
|
-
|
|
330
|
+
annotations_json: PathLike,
|
|
331
|
+
images_path: PathLike,
|
|
227
332
|
task_name: str | None = None,
|
|
228
|
-
|
|
333
|
+
annotation_type: AnnotationType = AnnotationType.BBOX,
|
|
334
|
+
embed: bool = True,
|
|
335
|
+
) -> None:
|
|
229
336
|
"""Load a dataset in COCO Object Detection format and store in DB.
|
|
230
337
|
|
|
231
338
|
Args:
|
|
232
|
-
|
|
233
|
-
|
|
339
|
+
annotations_json: Path to the COCO annotations JSON file.
|
|
340
|
+
images_path: Path to the folder containing the images.
|
|
234
341
|
task_name: Optional name for the annotation task. If None, a
|
|
235
342
|
default name is generated.
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
343
|
+
annotation_type: The type of annotation to be loaded (e.g., 'ObjectDetection',
|
|
344
|
+
'InstanceSegmentation').
|
|
345
|
+
embed: If True, generate embeddings for the newly added samples.
|
|
239
346
|
"""
|
|
240
|
-
annotations_json
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
if task_name is None:
|
|
244
|
-
task_name = f"Loaded from COCO Object Detection: {annotations_json.name}"
|
|
245
|
-
|
|
246
|
-
label_input = COCOObjectDetectionInput(
|
|
247
|
-
input_file=annotations_json,
|
|
248
|
-
)
|
|
249
|
-
img_dir_path = Path(img_dir).absolute()
|
|
347
|
+
if isinstance(annotations_json, str):
|
|
348
|
+
annotations_json = Path(annotations_json)
|
|
349
|
+
annotations_json = annotations_json.absolute()
|
|
250
350
|
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
dataset_name=dataset_name,
|
|
254
|
-
img_dir=str(img_dir_path),
|
|
255
|
-
is_prediction=False,
|
|
256
|
-
task_name=task_name,
|
|
257
|
-
)
|
|
351
|
+
if not annotations_json.is_file() or annotations_json.suffix != ".json":
|
|
352
|
+
raise FileNotFoundError(f"COCO annotations json file not found: '{annotations_json}'")
|
|
258
353
|
|
|
259
|
-
|
|
260
|
-
self,
|
|
261
|
-
annotations_json_path: str,
|
|
262
|
-
img_dir: str,
|
|
263
|
-
task_name: str | None = None,
|
|
264
|
-
) -> DatasetTable:
|
|
265
|
-
"""Load a dataset in COCO Instance Segmentation format and store in DB.
|
|
354
|
+
label_input: COCOObjectDetectionInput | COCOInstanceSegmentationInput
|
|
266
355
|
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
356
|
+
if annotation_type == AnnotationType.BBOX:
|
|
357
|
+
label_input = COCOObjectDetectionInput(
|
|
358
|
+
input_file=annotations_json,
|
|
359
|
+
)
|
|
360
|
+
task_name_default = f"Loaded from COCO Object Detection: {annotations_json.name}"
|
|
361
|
+
elif annotation_type == AnnotationType.INSTANCE_SEGMENTATION:
|
|
362
|
+
label_input = COCOInstanceSegmentationInput(
|
|
363
|
+
input_file=annotations_json,
|
|
364
|
+
)
|
|
365
|
+
task_name_default = f"Loaded from COCO Instance Segmentation: {annotations_json.name}"
|
|
366
|
+
else:
|
|
367
|
+
raise ValueError(f"Invalid annotation type: {annotation_type}")
|
|
278
368
|
|
|
279
369
|
if task_name is None:
|
|
280
|
-
task_name =
|
|
370
|
+
task_name = task_name_default
|
|
281
371
|
|
|
282
|
-
|
|
283
|
-
input_file=annotations_json,
|
|
284
|
-
)
|
|
285
|
-
img_dir_path = Path(img_dir).absolute()
|
|
372
|
+
images_path = Path(images_path).absolute()
|
|
286
373
|
|
|
287
|
-
|
|
374
|
+
self.add_samples_from_labelformat(
|
|
288
375
|
input_labels=label_input,
|
|
289
|
-
|
|
290
|
-
img_dir=str(img_dir_path),
|
|
376
|
+
images_path=images_path,
|
|
291
377
|
is_prediction=False,
|
|
292
378
|
task_name=task_name,
|
|
379
|
+
embed=embed,
|
|
293
380
|
)
|
|
294
381
|
|
|
295
|
-
@staticmethod
|
|
296
|
-
def load_from_db(name: str, db_path: PathLike) -> Dataset:
|
|
297
|
-
"""Load a dataset from the database.
|
|
298
382
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
"""
|
|
302
|
-
raise NotImplementedError
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
def _load_into_dataset_from_paths(
|
|
306
|
-
dataset_id: UUID,
|
|
307
|
-
session: Session,
|
|
308
|
-
image_paths: Iterable[Path],
|
|
309
|
-
) -> None:
|
|
310
|
-
samples_to_create: list[SampleCreate] = []
|
|
311
|
-
|
|
312
|
-
for image_path in tqdm(
|
|
313
|
-
image_paths,
|
|
314
|
-
desc="Processing images",
|
|
315
|
-
unit=" images",
|
|
316
|
-
):
|
|
317
|
-
try:
|
|
318
|
-
image = PIL.Image.open(image_path)
|
|
319
|
-
width, height = image.size
|
|
320
|
-
image.close()
|
|
321
|
-
except (FileNotFoundError, PIL.UnidentifiedImageError, OSError):
|
|
322
|
-
continue
|
|
323
|
-
|
|
324
|
-
sample = SampleCreate(
|
|
325
|
-
file_name=image_path.name,
|
|
326
|
-
file_path_abs=str(image_path),
|
|
327
|
-
width=width,
|
|
328
|
-
height=height,
|
|
329
|
-
dataset_id=dataset_id,
|
|
330
|
-
)
|
|
331
|
-
samples_to_create.append(sample)
|
|
332
|
-
|
|
333
|
-
# Process batch when it reaches SAMPLE_BATCH_SIZE
|
|
334
|
-
if len(samples_to_create) >= SAMPLE_BATCH_SIZE:
|
|
335
|
-
_ = sample_resolver.create_many(session=session, samples=samples_to_create)
|
|
336
|
-
samples_to_create = []
|
|
337
|
-
|
|
338
|
-
# Handle remaining samples
|
|
339
|
-
if samples_to_create:
|
|
340
|
-
_ = sample_resolver.create_many(session=session, samples=samples_to_create)
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
def _load_into_dataset(
|
|
344
|
-
session: Session,
|
|
345
|
-
dataset_id: UUID,
|
|
346
|
-
input_labels: ObjectDetectionInput | InstanceSegmentationInput,
|
|
347
|
-
images_path: Path,
|
|
348
|
-
annotation_task_id: UUID,
|
|
349
|
-
) -> None:
|
|
350
|
-
"""Store a loaded dataset in database."""
|
|
351
|
-
# Create label mapping
|
|
352
|
-
label_map = _create_label_map(session=session, input_labels=input_labels)
|
|
353
|
-
|
|
354
|
-
annotations_to_create: list[AnnotationCreate] = []
|
|
355
|
-
sample_ids: list[UUID] = []
|
|
356
|
-
samples_to_create: list[SampleCreate] = []
|
|
357
|
-
samples_image_data: list[
|
|
358
|
-
tuple[SampleCreate, ImageInstanceSegmentation | ImageObjectDetection]
|
|
359
|
-
] = []
|
|
360
|
-
|
|
361
|
-
for image_data in tqdm(input_labels.get_labels(), desc="Processing images", unit=" images"):
|
|
362
|
-
image: Image = image_data.image # type: ignore[attr-defined]
|
|
363
|
-
|
|
364
|
-
typed_image_data: ImageInstanceSegmentation | ImageObjectDetection = image_data # type: ignore[assignment]
|
|
365
|
-
sample = SampleCreate(
|
|
366
|
-
file_name=str(image.filename),
|
|
367
|
-
file_path_abs=str(images_path / image.filename),
|
|
368
|
-
width=image.width,
|
|
369
|
-
height=image.height,
|
|
370
|
-
dataset_id=dataset_id,
|
|
371
|
-
)
|
|
372
|
-
samples_to_create.append(sample)
|
|
373
|
-
samples_image_data.append((sample, typed_image_data))
|
|
374
|
-
|
|
375
|
-
if len(samples_to_create) >= SAMPLE_BATCH_SIZE:
|
|
376
|
-
stored_samples = sample_resolver.create_many(session=session, samples=samples_to_create)
|
|
377
|
-
_process_batch_annotations(
|
|
378
|
-
session=session,
|
|
379
|
-
stored_samples=stored_samples,
|
|
380
|
-
samples_data=samples_image_data,
|
|
381
|
-
dataset_id=dataset_id,
|
|
382
|
-
label_map=label_map,
|
|
383
|
-
annotation_task_id=annotation_task_id,
|
|
384
|
-
annotations_to_create=annotations_to_create,
|
|
385
|
-
sample_ids=sample_ids,
|
|
386
|
-
)
|
|
387
|
-
samples_to_create.clear()
|
|
388
|
-
samples_image_data.clear()
|
|
389
|
-
|
|
390
|
-
if samples_to_create:
|
|
391
|
-
stored_samples = sample_resolver.create_many(session=session, samples=samples_to_create)
|
|
392
|
-
_process_batch_annotations(
|
|
393
|
-
session=session,
|
|
394
|
-
stored_samples=stored_samples,
|
|
395
|
-
samples_data=samples_image_data,
|
|
396
|
-
dataset_id=dataset_id,
|
|
397
|
-
label_map=label_map,
|
|
398
|
-
annotation_task_id=annotation_task_id,
|
|
399
|
-
annotations_to_create=annotations_to_create,
|
|
400
|
-
sample_ids=sample_ids,
|
|
401
|
-
)
|
|
402
|
-
|
|
403
|
-
# Insert any remaining annotations
|
|
404
|
-
if annotations_to_create:
|
|
405
|
-
annotation_resolver.create_many(session=session, annotations=annotations_to_create)
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
def _create_label_map(
    session: Session,
    input_labels: ObjectDetectionInput | InstanceSegmentationInput,
) -> dict[int, UUID]:
    """Create a mapping of category IDs to annotation label IDs.

    Each category from the input labels is persisted as an annotation label,
    and its original integer category ID is mapped to the stored label's UUID.

    Args:
        session: Database session used for the label resolver.
        input_labels: Source of categories to persist.

    Returns:
        Mapping from input category ID to stored annotation label ID.
    """
    categories = tqdm(
        input_labels.get_categories(),
        desc="Processing categories",
        unit=" categories",
    )
    mapping: dict[int, UUID] = {}
    for category in categories:
        stored = annotation_label_resolver.create(
            session=session,
            label=AnnotationLabelCreate(annotation_label_name=category.name),
        )
        mapping[category.id] = stored.annotation_label_id
    return mapping
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
def _process_object_detection_annotations(
    context: AnnotationProcessingContext,
    image_data: ImageObjectDetection,
) -> list[AnnotationCreate]:
    """Process object detection annotations for a single image.

    Converts each detected object's bounding box to XYWH format and builds an
    ``AnnotationCreate`` for it using the IDs carried by *context*.

    Args:
        context: Dataset/sample/task IDs and the category-to-label mapping.
        image_data: Detections for one image.

    Returns:
        Annotation create models, one per detected object.
    """
    annotations: list[AnnotationCreate] = []
    for detection in image_data.objects:
        # Normalize the box to (x, y, width, height) before storing.
        x, y, width, height = detection.box.to_format(BoundingBoxFormat.XYWH)
        annotations.append(
            AnnotationCreate(
                dataset_id=context.dataset_id,
                sample_id=context.sample_id,
                annotation_label_id=context.label_map[detection.category.id],
                annotation_type="object_detection",
                x=x,
                y=y,
                width=width,
                height=height,
                confidence=detection.confidence,
                annotation_task_id=context.annotation_task_id,
            )
        )
    return annotations
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
def _process_instance_segmentation_annotations(
    context: AnnotationProcessingContext,
    image_data: ImageInstanceSegmentation,
) -> list[AnnotationCreate]:
    """Process instance segmentation annotations for a single image.

    Supports two segmentation representations: ``MultiPolygon`` (bounding box
    derived from the polygon, no RLE mask stored) and
    ``BinaryMaskSegmentation`` (bounding box plus a row-wise RLE mask).

    Args:
        context: Dataset/sample/task IDs and the category-to-label mapping.
        image_data: Segmented instances for one image.

    Returns:
        Annotation create models, one per segmented instance.

    Raises:
        ValueError: If an instance carries an unsupported segmentation type.
    """
    annotations: list[AnnotationCreate] = []
    for instance in image_data.objects:
        segmentation = instance.segmentation
        rle: None | list[int] = None
        if isinstance(segmentation, MultiPolygon):
            bbox = segmentation.bounding_box().to_format(BoundingBoxFormat.XYWH)
        elif isinstance(segmentation, BinaryMaskSegmentation):
            bbox = segmentation.bounding_box.to_format(BoundingBoxFormat.XYWH)
            rle = segmentation._rle_row_wise  # noqa: SLF001
        else:
            raise ValueError(f"Unsupported segmentation type: {type(segmentation)}")

        x, y, width, height = bbox
        annotations.append(
            AnnotationCreate(
                dataset_id=context.dataset_id,
                sample_id=context.sample_id,
                annotation_label_id=context.label_map[instance.category.id],
                annotation_type="instance_segmentation",
                x=x,
                y=y,
                width=width,
                height=height,
                segmentation_mask=rle,
                annotation_task_id=context.annotation_task_id,
            )
        )
    return annotations
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
def _process_batch_annotations( # noqa: PLR0913
|
|
488
|
-
session: Session,
|
|
489
|
-
stored_samples: list[SampleTable],
|
|
490
|
-
samples_data: list[tuple[SampleCreate, ImageInstanceSegmentation | ImageObjectDetection]],
|
|
491
|
-
dataset_id: UUID,
|
|
492
|
-
label_map: dict[int, UUID],
|
|
493
|
-
annotation_task_id: UUID,
|
|
494
|
-
annotations_to_create: list[AnnotationCreate],
|
|
495
|
-
sample_ids: list[UUID],
|
|
496
|
-
) -> None:
|
|
497
|
-
"""Process annotations for a batch of samples."""
|
|
498
|
-
for stored_sample, (_, img_data) in zip(stored_samples, samples_data):
|
|
499
|
-
sample_ids.append(stored_sample.sample_id)
|
|
500
|
-
|
|
501
|
-
context = AnnotationProcessingContext(
|
|
502
|
-
dataset_id=dataset_id,
|
|
503
|
-
sample_id=stored_sample.sample_id,
|
|
504
|
-
label_map=label_map,
|
|
505
|
-
annotation_task_id=annotation_task_id,
|
|
506
|
-
)
|
|
383
|
+
def _generate_embeddings(session: Session, dataset_id: UUID, sample_ids: list[UUID]) -> None:
|
|
384
|
+
"""Generate and store embeddings for samples.
|
|
507
385
|
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
386
|
+
Args:
|
|
387
|
+
session: Database session for resolver operations.
|
|
388
|
+
dataset_id: The ID of the dataset to associate with the embedding model.
|
|
389
|
+
sample_ids: List of sample IDs to generate embeddings for.
|
|
390
|
+
"""
|
|
391
|
+
if not sample_ids:
|
|
392
|
+
return
|
|
393
|
+
|
|
394
|
+
embedding_manager = EmbeddingManagerProvider.get_embedding_manager()
|
|
395
|
+
model_id = embedding_manager.load_or_get_default_model(
|
|
396
|
+
session=session,
|
|
397
|
+
dataset_id=dataset_id,
|
|
398
|
+
)
|
|
399
|
+
if model_id is None:
|
|
400
|
+
print("No embedding model loaded. Skipping embedding generation.")
|
|
401
|
+
return
|
|
402
|
+
|
|
403
|
+
embedding_manager.embed_images(
|
|
404
|
+
session=session,
|
|
405
|
+
sample_ids=sample_ids,
|
|
406
|
+
embedding_model_id=model_id,
|
|
407
|
+
)
|
|
408
|
+
|
|
409
|
+
# Mark the embedding search feature as enabled.
|
|
410
|
+
if "embeddingSearchEnabled" not in features.lightly_studio_active_features:
|
|
411
|
+
features.lightly_studio_active_features.append("embeddingSearchEnabled")
|