lightly-studio 0.3.3-py3-none-any.whl → 0.3.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of lightly-studio might be problematic.

Files changed (122)
  1. lightly_studio/api/app.py +2 -0
  2. lightly_studio/api/routes/api/caption.py +30 -0
  3. lightly_studio/api/routes/api/embeddings2d.py +36 -4
  4. lightly_studio/api/routes/api/metadata.py +57 -1
  5. lightly_studio/core/add_samples.py +138 -0
  6. lightly_studio/core/dataset.py +143 -16
  7. lightly_studio/dataset/loader.py +2 -8
  8. lightly_studio/db_manager.py +10 -4
  9. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.B3oFNb6O.css +1 -0
  10. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/2.CkOblLn7.css +1 -0
  11. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/Samples.CIbricz7.css +1 -0
  12. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.7Ma7YdVg.css +1 -0
  13. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/{useFeatureFlags.CV-KWLNP.css → _layout.CefECEWA.css} +1 -1
  14. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/transform.2jKMtOWG.css +1 -0
  15. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/-DXuGN29.js +1 -0
  16. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{Cs1XmhiF.js → B7302SU7.js} +1 -1
  17. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BeWf8-vJ.js +1 -0
  18. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bqz7dyEC.js +1 -0
  19. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C1FmrZbK.js +1 -0
  20. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{BdfTHw61.js → CSCQddQS.js} +1 -1
  21. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CZGpyrcA.js +1 -0
  22. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CfQ4mGwl.js +1 -0
  23. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CiaNZCBa.js +1 -0
  24. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cqo0Vpvt.js +417 -0
  25. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cy4fgWTG.js +1 -0
  26. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5w4xp5l.js +1 -0
  27. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DD63uD-T.js +1 -0
  28. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DQ8aZ1o-.js +3 -0
  29. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{keKYsoph.js → DSxvnAMh.js} +1 -1
  30. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D_JuJOO3.js +20 -0
  31. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D_ynJAfY.js +2 -0
  32. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Dafy4oEQ.js +1 -0
  33. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{BfHVnyNT.js → Dj4O-5se.js} +1 -1
  34. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DmjAI-UV.js +1 -0
  35. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Dug7Bq1S.js +1 -0
  36. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Dv5BSBQG.js +1 -0
  37. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DzBTnFhV.js +1 -0
  38. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DzX_yyqb.js +1 -0
  39. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Frwd2CjB.js +1 -0
  40. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H4l0JFh9.js +1 -0
  41. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H60ATh8g.js +2 -0
  42. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{6t3IJ0vQ.js → qIv1kPyv.js} +1 -1
  43. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/sLqs1uaK.js +20 -0
  44. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/u-it74zV.js +96 -0
  45. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.BPc0HQPq.js +2 -0
  46. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.SNvc2nrm.js +1 -0
  47. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.5jT7P06o.js +1 -0
  48. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.Cdy-7S5q.js +1 -0
  49. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.C_uoESTX.js +1 -0
  50. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.DcO8wIAc.js +1 -0
  51. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{2.C8HLK8mj.js → 2.BIldfkxL.js} +268 -113
  52. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{3.CLvg3QcJ.js → 3.BC9z_TWM.js} +1 -1
  53. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{4.BQhDtXUI.js → 4.D8X_Ch5n.js} +1 -1
  54. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.CAXhxJu6.js +39 -0
  55. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{6.uBV1Lhat.js → 6.DRA5Ru_2.js} +1 -1
  56. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.WVBsruHQ.js +1 -0
  57. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.BuKUrCEN.js +20 -0
  58. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CUIn1yCR.js +1 -0
  59. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -1
  60. lightly_studio/dist_lightly_studio_view_app/index.html +15 -14
  61. lightly_studio/examples/example.py +4 -0
  62. lightly_studio/examples/example_coco.py +4 -0
  63. lightly_studio/examples/example_coco_caption.py +24 -0
  64. lightly_studio/examples/example_metadata.py +4 -1
  65. lightly_studio/examples/example_selection.py +4 -0
  66. lightly_studio/examples/example_split_work.py +4 -0
  67. lightly_studio/examples/example_yolo.py +4 -0
  68. lightly_studio/export/export_dataset.py +11 -3
  69. lightly_studio/metadata/compute_typicality.py +1 -1
  70. lightly_studio/models/caption.py +73 -0
  71. lightly_studio/models/dataset.py +1 -2
  72. lightly_studio/models/metadata.py +1 -1
  73. lightly_studio/models/sample.py +2 -2
  74. lightly_studio/resolvers/caption_resolver.py +80 -0
  75. lightly_studio/resolvers/dataset_resolver.py +4 -7
  76. lightly_studio/resolvers/metadata_resolver/__init__.py +2 -2
  77. lightly_studio/resolvers/metadata_resolver/sample/__init__.py +3 -3
  78. lightly_studio/resolvers/metadata_resolver/sample/bulk_update_metadata.py +46 -0
  79. lightly_studio/resolvers/samples_filter.py +18 -10
  80. lightly_studio/type_definitions.py +2 -0
  81. {lightly_studio-0.3.3.dist-info → lightly_studio-0.3.4.dist-info}/METADATA +86 -21
  82. {lightly_studio-0.3.3.dist-info → lightly_studio-0.3.4.dist-info}/RECORD +83 -77
  83. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.CA_CXIBb.css +0 -1
  84. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.DS78jgNY.css +0 -1
  85. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/index.BVs_sZj9.css +0 -1
  86. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/transform.D487hwJk.css +0 -1
  87. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/8NsknIT2.js +0 -1
  88. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BND_-4Kp.js +0 -1
  89. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BjkP1AHA.js +0 -1
  90. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BuuNVL9G.js +0 -1
  91. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BzKGpnl4.js +0 -1
  92. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CCx7Ho51.js +0 -1
  93. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CH6P3X75.js +0 -1
  94. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CR2upx_Q.js +0 -4
  95. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWPZrTTJ.js +0 -1
  96. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CwPowJfP.js +0 -1
  97. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxFKfZ9T.js +0 -1
  98. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cxevwdid.js +0 -1
  99. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D4whDBUi.js +0 -1
  100. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6r9vr07.js +0 -1
  101. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DA6bFLPR.js +0 -1
  102. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DEgUu98i.js +0 -3
  103. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DGTPl6Gk.js +0 -1
  104. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DKGxBSlK.js +0 -1
  105. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DQXoLcsF.js +0 -1
  106. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DQe_kdRt.js +0 -92
  107. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DcY4jgG3.js +0 -1
  108. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +0 -1
  109. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/RmD8FzRo.js +0 -1
  110. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/V-MnMC1X.js +0 -1
  111. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.BVr6DYqP.js +0 -2
  112. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.u7zsVvqp.js +0 -1
  113. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.Da2agmdd.js +0 -1
  114. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B11tVRJV.js +0 -1
  115. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.l30Zud4h.js +0 -1
  116. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CgKPGcAP.js +0 -1
  117. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.-6XqWX5G.js +0 -1
  118. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.BXsgoQZh.js +0 -1
  119. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.BkbcnUs8.js +0 -1
  120. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.Bkrv-Vww.js +0 -1
  121. lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +0 -48
  122. {lightly_studio-0.3.3.dist-info → lightly_studio-0.3.4.dist-info}/WHEEL +0 -0
lightly_studio/api/app.py CHANGED
@@ -16,6 +16,7 @@ from lightly_studio.api.routes import healthz, images, webapp
 from lightly_studio.api.routes.api import (
     annotation,
     annotation_label,
+    caption,
     classifier,
     dataset,
     dataset_tag,
@@ -89,6 +90,7 @@ api_router.include_router(export.export_router)
 api_router.include_router(sample.samples_router)
 api_router.include_router(annotation_label.annotations_label_router)
 api_router.include_router(annotation.annotations_router)
+api_router.include_router(caption.captions_router)
 api_router.include_router(text_embedding.text_embedding_router)
 api_router.include_router(settings.settings_router)
 api_router.include_router(classifier.classifier_router)
lightly_studio/api/routes/api/caption.py ADDED
@@ -0,0 +1,30 @@
+"""API routes for dataset captions."""
+
+from __future__ import annotations
+
+from uuid import UUID
+
+from fastapi import APIRouter, Depends, Path
+from typing_extensions import Annotated
+
+from lightly_studio.api.routes.api.validators import Paginated, PaginatedWithCursor
+from lightly_studio.db_manager import SessionDep
+from lightly_studio.models.caption import CaptionsListView
+from lightly_studio.resolvers import caption_resolver
+from lightly_studio.resolvers.caption_resolver import GetAllCaptionsResult
+
+captions_router = APIRouter(prefix="/datasets/{dataset_id}", tags=["captions"])
+
+
+@captions_router.get("/captions", response_model=CaptionsListView)
+def read_captions(
+    dataset_id: Annotated[UUID, Path(title="Dataset Id")],
+    session: SessionDep,
+    pagination: Annotated[PaginatedWithCursor, Depends()],
+) -> GetAllCaptionsResult:
+    """Retrieve captions for a dataset."""
+    return caption_resolver.get_all(
+        session=session,
+        dataset_id=dataset_id,
+        pagination=Paginated(offset=pagination.offset, limit=pagination.limit),
+    )
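The new captions route is read-only and mounted under the dataset prefix, so a client can page through a dataset's captions with a plain GET request. A minimal client-side sketch follows; the base URL, API prefix, and the offset/limit query parameter names are assumptions, not details confirmed by this diff.

    # Hypothetical client call for the new captions endpoint; base URL, API prefix,
    # and pagination parameter names are assumptions.
    import requests

    BASE_URL = "http://localhost:8000/api"  # assumed API prefix
    DATASET_ID = "00000000-0000-0000-0000-000000000000"  # placeholder dataset UUID

    response = requests.get(
        f"{BASE_URL}/datasets/{DATASET_ID}/captions",
        params={"offset": 0, "limit": 50},  # assumed fields of PaginatedWithCursor
        timeout=10,
    )
    response.raise_for_status()
    captions_page = response.json()  # serialized CaptionsListView
    print(captions_page)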
lightly_studio/api/routes/api/embeddings2d.py CHANGED
@@ -3,25 +3,40 @@
 from __future__ import annotations
 
 import io
+from uuid import UUID
 
 import numpy as np
 import pyarrow as pa
 from fastapi import APIRouter, HTTPException, Response
 from numpy.typing import NDArray
 from pyarrow import ipc
+from pydantic import BaseModel, Field
 from sklearn.manifold import TSNE
 from sqlmodel import select
 
 from lightly_studio.db_manager import SessionDep
 from lightly_studio.models.dataset import DatasetTable
 from lightly_studio.models.embedding_model import EmbeddingModelTable
-from lightly_studio.resolvers import sample_embedding_resolver
+from lightly_studio.resolvers import sample_embedding_resolver, sample_resolver
+from lightly_studio.resolvers.samples_filter import SampleFilter
 
 embeddings2d_router = APIRouter()
 
 
-@embeddings2d_router.get("/embeddings2d/tsne")
-def get_embeddings2d__tsne(session: SessionDep) -> Response:
+class GetEmbeddings2DRequest(BaseModel):
+    """Request body for retrieving 2D embeddings."""
+
+    filters: SampleFilter | None = Field(
+        None,
+        description="Filter parameters identifying matching samples",
+    )
+
+
+@embeddings2d_router.post("/embeddings2d/tsne")
+def get_embeddings2d__tsne(
+    session: SessionDep,
+    body: GetEmbeddings2DRequest | None = None,
+) -> Response:
     """Return 2D embeddings serialized as an Arrow stream."""
     # TODO(Malte, 09/2025): Support choosing the dataset via API parameter.
     dataset = session.exec(select(DatasetTable).limit(1)).first()
@@ -37,7 +52,6 @@ def get_embeddings2d__tsne(session: SessionDep) -> Response:
     if embedding_model is None:
         raise HTTPException(status_code=404, detail="No embedding model configured.")
 
-    # TODO(Malte, 09/2025): Support choosing a subset of samples via API parameter.
     embeddings = sample_embedding_resolver.get_all_by_dataset_id(
         session=session,
         dataset_id=dataset.dataset_id,
@@ -49,6 +63,22 @@ def get_embeddings2d__tsne(session: SessionDep) -> Response:
     x = embedding_values_tsne[:, 0]
     y = embedding_values_tsne[:, 1]
 
+    matching_sample_ids: set[UUID] | None = None
+    filters = body.filters if body else None
+    if filters:
+        matching_samples_result = sample_resolver.get_all_by_dataset_id(
+            session=session,
+            dataset_id=dataset.dataset_id,
+            filters=filters,
+        )
+        matching_sample_ids = {sample.sample_id for sample in matching_samples_result.samples}
+
+    sample_ids = [embedding.sample_id for embedding in embeddings]
+    if matching_sample_ids is None:
+        fulfils_filter = [1] * len(sample_ids)
+    else:
+        fulfils_filter = [1 if sample_id in matching_sample_ids else 0 for sample_id in sample_ids]
+
     # TODO(Malte, 09/2025): Save the 2D-embeddings in the database to avoid recomputing
     # them on every request.
 
@@ -57,6 +87,8 @@ def get_embeddings2d__tsne(session: SessionDep) -> Response:
         {
             "x": pa.array(x, type=pa.float32()),
             "y": pa.array(y, type=pa.float32()),
+            "fulfils_filter": pa.array(fulfils_filter, type=pa.uint8()),
+            "sample_id": pa.array([str(sample_id) for sample_id in sample_ids], type=pa.string()),
         }
     )
 
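Since the t-SNE endpoint now answers POST requests and returns an Arrow IPC stream with x, y, fulfils_filter, and sample_id columns, the response can be read straight into an Arrow table. A sketch under the same assumptions (the base URL and the exact SampleFilter payload shape are not confirmed by this diff):

    # Sketch of consuming the Arrow stream returned by POST /embeddings2d/tsne.
    # The base URL and the filter payload shape are assumptions.
    import pyarrow as pa
    import requests

    BASE_URL = "http://localhost:8000/api"  # assumed API prefix

    response = requests.post(
        f"{BASE_URL}/embeddings2d/tsne",
        json={"filters": None},  # or a SampleFilter-shaped dict to flag matching samples
        timeout=120,
    )
    response.raise_for_status()

    # The body is an Arrow IPC stream; fulfils_filter is 1 for samples matching the filter.
    reader = pa.ipc.open_stream(response.content)
    table = reader.read_all()
    print(table.column_names)  # ['x', 'y', 'fulfils_filter', 'sample_id']
    print(table.num_rows, "embedded samples")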
lightly_studio/api/routes/api/metadata.py CHANGED
@@ -5,11 +5,16 @@ from __future__ import annotations
 from typing import List
 from uuid import UUID
 
-from fastapi import APIRouter, Path
+from fastapi import APIRouter, Depends, Path
+from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
+from lightly_studio.api.routes.api.dataset import get_and_validate_dataset_id
 from lightly_studio.db_manager import SessionDep
+from lightly_studio.metadata import compute_typicality
+from lightly_studio.models.dataset import DatasetTable
 from lightly_studio.models.metadata import MetadataInfoView
+from lightly_studio.resolvers import embedding_model_resolver
 from lightly_studio.resolvers.metadata_resolver.sample.get_metadata_info import (
     get_all_metadata_keys_and_schema,
 )
@@ -33,3 +38,54 @@ def get_metadata_info(
         for numerical metadata types.
     """
     return get_all_metadata_keys_and_schema(session=session, dataset_id=dataset_id)
+
+
+class ComputeTypicalityRequest(BaseModel):
+    """Request model for computing typicality metadata."""
+
+    embedding_model_name: str | None = Field(
+        default=None,
+        description="Embedding model name (uses default if not specified)",
+    )
+    metadata_name: str = Field(
+        default="typicality",
+        description="Metadata field name (defaults to 'typicality')",
+    )
+
+
+@metadata_router.post(
+    "/metadata/typicality",
+    status_code=204,
+    response_model=None,
+)
+def compute_typicality_metadata(
+    session: SessionDep,
+    dataset: Annotated[
+        DatasetTable,
+        Depends(get_and_validate_dataset_id),
+    ],
+    request: ComputeTypicalityRequest,
+) -> None:
+    """Compute typicality metadata for a dataset.
+
+    Args:
+        session: The database session.
+        dataset: The dataset to compute typicality for.
+        request: Request parameters including optional embedding model name
+            and metadata field name.
+
+    Returns:
+        None (204 No Content on success).
+    """
+    embedding_model = embedding_model_resolver.get_by_name(
+        session=session,
+        dataset_id=dataset.dataset_id,
+        embedding_model_name=request.embedding_model_name,
+    )
+
+    compute_typicality.compute_typicality_metadata(
+        session=session,
+        dataset_id=dataset.dataset_id,
+        embedding_model_id=embedding_model.embedding_model_id,
+        metadata_name=request.metadata_name,
+    )
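The typicality endpoint takes a small JSON body and returns 204 No Content on success. A hedged sketch of invoking it; the base URL and the router's dataset-scoped prefix are assumptions inferred from the other dataset routers in this release.

    # Hypothetical call to the new typicality endpoint; the URL layout is an assumption.
    import requests

    BASE_URL = "http://localhost:8000/api"  # assumed API prefix
    DATASET_ID = "00000000-0000-0000-0000-000000000000"  # placeholder dataset UUID

    response = requests.post(
        f"{BASE_URL}/datasets/{DATASET_ID}/metadata/typicality",
        json={
            "embedding_model_name": None,   # fall back to the default embedding model
            "metadata_name": "typicality",  # metadata key that will hold the scores
        },
        timeout=600,
    )
    assert response.status_code == 204  # success returns no body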
lightly_studio/core/add_samples.py CHANGED
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+import json
+from collections import defaultdict
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Iterable
@@ -26,10 +28,12 @@ from tqdm import tqdm
 
 from lightly_studio.models.annotation.annotation_base import AnnotationCreate
 from lightly_studio.models.annotation_label import AnnotationLabelCreate
+from lightly_studio.models.caption import CaptionCreate
 from lightly_studio.models.sample import SampleCreate, SampleTable
 from lightly_studio.resolvers import (
     annotation_label_resolver,
     annotation_resolver,
+    caption_resolver,
     sample_resolver,
 )
 
@@ -218,6 +222,111 @@
     return created_sample_ids
 
 
+def load_into_dataset_from_coco_captions(
+    session: Session,
+    dataset_id: UUID,
+    annotations_json: Path,
+    images_path: Path,
+) -> list[UUID]:
+    """Load samples and captions from a COCO captions file into the dataset.
+
+    Args:
+        session: Database session used for resolver operations.
+        dataset_id: Identifier of the dataset that receives the samples.
+        annotations_json: Path to the COCO captions annotations file.
+        images_path: Directory containing the referenced images.
+
+    Returns:
+        The list of newly created sample identifiers.
+    """
+    with fsspec.open(str(annotations_json), "r") as file:
+        coco_payload = json.load(file)
+
+    images: list[dict[str, object]] = coco_payload.get("images", [])
+    annotations: list[dict[str, object]] = coco_payload.get("annotations", [])
+
+    captions_by_image_id: dict[int, list[str]] = defaultdict(list)
+    for annotation in annotations:
+        image_id = annotation["image_id"]
+        caption = annotation["caption"]
+        if not isinstance(image_id, int):
+            continue
+        if not isinstance(caption, str):
+            continue
+        caption_text = caption.strip()
+        if not caption_text:
+            continue
+        captions_by_image_id[image_id].append(caption_text)
+
+    logging_context = _LoadingLoggingContext(
+        n_samples_to_be_inserted=len(images),
+        n_samples_before_loading=sample_resolver.count_by_dataset_id(
+            session=session, dataset_id=dataset_id
+        ),
+    )
+
+    captions_to_create: list[CaptionCreate] = []
+    samples_to_create: list[SampleCreate] = []
+    created_sample_ids: list[UUID] = []
+    image_path_to_captions: dict[str, list[str]] = {}
+
+    for image_info in tqdm(images, desc="Processing images", unit=" images"):
+        if isinstance(image_info["id"], int):
+            image_id_raw = image_info["id"]
+        else:
+            continue
+        file_name_raw = str(image_info["file_name"])
+
+        width = image_info["width"] if isinstance(image_info["width"], int) else 0
+        height = image_info["height"] if isinstance(image_info["height"], int) else 0
+        sample = SampleCreate(
+            file_name=file_name_raw,
+            file_path_abs=str(images_path / file_name_raw),
+            width=width,
+            height=height,
+            dataset_id=dataset_id,
+        )
+        samples_to_create.append(sample)
+        image_path_to_captions[sample.file_path_abs] = captions_by_image_id.get(image_id_raw, [])
+
+        if len(samples_to_create) >= SAMPLE_BATCH_SIZE:
+            created_samples_batch, paths_not_inserted = _create_batch_samples(
+                session=session, samples=samples_to_create
+            )
+            created_sample_ids.extend(s.sample_id for s in created_samples_batch)
+            logging_context.update_example_paths(paths_not_inserted)
+            _process_batch_captions(
+                session=session,
+                dataset_id=dataset_id,
+                stored_samples=created_samples_batch,
+                image_path_to_captions=image_path_to_captions,
+                captions_to_create=captions_to_create,
+            )
+            samples_to_create.clear()
+            image_path_to_captions.clear()
+
+    if samples_to_create:
+        created_samples_batch, paths_not_inserted = _create_batch_samples(
+            session=session, samples=samples_to_create
+        )
+        created_sample_ids.extend(s.sample_id for s in created_samples_batch)
+        logging_context.update_example_paths(paths_not_inserted)
+        _process_batch_captions(
+            session=session,
+            dataset_id=dataset_id,
+            stored_samples=created_samples_batch,
+            image_path_to_captions=image_path_to_captions,
+            captions_to_create=captions_to_create,
+        )
+
+    if captions_to_create:
+        caption_resolver.create_many(session=session, captions=captions_to_create)
+
+    _log_loading_results(session=session, dataset_id=dataset_id, logging_context=logging_context)
+
+    return created_sample_ids
+
+
 def _log_loading_results(
     session: Session, dataset_id: UUID, logging_context: _LoadingLoggingContext
 ) -> None:
@@ -372,3 +481,32 @@ def _process_batch_annotations(  # noqa: PLR0913
         if len(annotations_to_create) >= ANNOTATION_BATCH_SIZE:
             annotation_resolver.create_many(session=session, annotations=annotations_to_create)
             annotations_to_create.clear()
+
+
+def _process_batch_captions(
+    session: Session,
+    dataset_id: UUID,
+    stored_samples: list[SampleTable],
+    image_path_to_captions: dict[str, list[str]],
+    captions_to_create: list[CaptionCreate],
+) -> None:
+    """Process captions for a batch of samples."""
+    if not stored_samples:
+        return
+
+    for stored_sample in stored_samples:
+        captions = image_path_to_captions[stored_sample.file_path_abs]
+        if not captions:
+            continue
+
+        for caption_text in captions:
+            caption = CaptionCreate(
+                dataset_id=dataset_id,
+                sample_id=stored_sample.sample_id,
+                text=caption_text,
+            )
+            captions_to_create.append(caption)
+
+        if len(captions_to_create) >= ANNOTATION_BATCH_SIZE:
+            caption_resolver.create_many(session=session, captions=captions_to_create)
+            captions_to_create.clear()
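For reference, load_into_dataset_from_coco_captions expects the standard COCO captions layout: an "images" list with integer "id", "file_name", "width", and "height", and an "annotations" list pairing "image_id" with a "caption" string. An illustrative file (image names and dimensions are placeholders) could be produced like this:

    # Illustrative COCO-captions payload matching what the loader above reads;
    # image file names and dimensions are placeholders.
    import json
    from pathlib import Path

    coco_captions = {
        "images": [
            {"id": 1, "file_name": "000001.jpg", "width": 640, "height": 480},
            {"id": 2, "file_name": "000002.jpg", "width": 640, "height": 480},
        ],
        "annotations": [
            {"image_id": 1, "caption": "A dog running on a beach."},
            {"image_id": 1, "caption": "A brown dog plays in the sand."},
            {"image_id": 2, "caption": "A red bicycle leaning against a wall."},
        ],
    }

    Path("captions_train.json").write_text(json.dumps(coco_captions, indent=2))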
lightly_studio/core/dataset.py CHANGED
@@ -6,6 +6,7 @@ from pathlib import Path
 from typing import Iterable, Iterator
 from uuid import UUID
 
+import yaml
 from labelformat.formats import (
     COCOInstanceSegmentationInput,
     COCOObjectDetectionInput,
@@ -38,11 +39,13 @@ from lightly_studio.resolvers import (
     dataset_resolver,
     embedding_model_resolver,
     sample_resolver,
+    tag_resolver,
 )
 from lightly_studio.type_definitions import PathLike
 
 # Constants
 DEFAULT_DATASET_NAME = "default_dataset"
+ALLOWED_YOLO_SPLITS = {"train", "val", "test", "minival"}
 
 _SliceType = slice  # to avoid shadowing built-in slice in type annotations
 
@@ -68,7 +71,7 @@ class Dataset:
 
         dataset = dataset_resolver.create(
             session=db_manager.persistent_session(),
-            dataset=DatasetCreate(name=name, directory=""),
+            dataset=DatasetCreate(name=name),
         )
         return Dataset(dataset=dataset)
 
@@ -262,14 +265,15 @@
     def add_samples_from_yolo(
         self,
         data_yaml: PathLike,
-        input_split: str = "train",
+        input_split: str | None = None,
         embed: bool = True,
     ) -> None:
        """Load a dataset in YOLO format and store in DB.

        Args:
            data_yaml: Path to the YOLO data.yaml file.
-            input_split: The split to load (e.g., 'train', 'val').
+            input_split: The split to load (e.g., 'train', 'val', 'test').
+                If None, all available splits will be loaded and assigned a corresponding tag.
            embed: If True, generate embeddings for the newly added samples.
        """
        if isinstance(data_yaml, str):
@@ -279,24 +283,54 @@
         if not data_yaml.is_file() or data_yaml.suffix != ".yaml":
             raise FileNotFoundError(f"YOLO data yaml file not found: '{data_yaml}'")
 
-        # Load the dataset using labelformat.
-        label_input = YOLOv8ObjectDetectionInput(
-            input_file=data_yaml,
-            input_split=input_split,
-        )
-        images_path = label_input._images_dir()  # noqa: SLF001
+        # Determine which splits to process
+        splits_to_process = _resolve_yolo_splits(data_yaml=data_yaml, input_split=input_split)
 
-        self.add_samples_from_labelformat(
-            input_labels=label_input,
-            images_path=images_path,
-            embed=embed,
-        )
+        all_created_sample_ids = []
+
+        # Process each split
+        for split in splits_to_process:
+            # Load the dataset using labelformat.
+            label_input = YOLOv8ObjectDetectionInput(
+                input_file=data_yaml,
+                input_split=split,
+            )
+            images_path = label_input._images_dir()  # noqa: SLF001
+
+            created_sample_ids = add_samples.load_into_dataset_from_labelformat(
+                session=self.session,
+                dataset_id=self.dataset_id,
+                input_labels=label_input,
+                images_path=images_path,
+            )
+
+            # Tag samples with split name
+            if created_sample_ids:
+                tag = tag_resolver.get_or_create_sample_tag_by_name(
+                    session=self.session,
+                    dataset_id=self.dataset_id,
+                    tag_name=split,
+                )
+                tag_resolver.add_sample_ids_to_tag_id(
+                    session=self.session,
+                    tag_id=tag.tag_id,
+                    sample_ids=created_sample_ids,
+                )
+
+            all_created_sample_ids.extend(created_sample_ids)
+
+        # Generate embeddings for all samples at once
+        if embed:
+            _generate_embeddings(
+                session=self.session, dataset_id=self.dataset_id, sample_ids=all_created_sample_ids
+            )
 
     def add_samples_from_coco(
         self,
         annotations_json: PathLike,
         images_path: PathLike,
         annotation_type: AnnotationType = AnnotationType.OBJECT_DETECTION,
+        split: str | None = None,
         embed: bool = True,
     ) -> None:
         """Load a dataset in COCO Object Detection format and store in DB.
@@ -306,6 +340,8 @@
             images_path: Path to the folder containing the images.
             annotation_type: The type of annotation to be loaded (e.g., 'ObjectDetection',
                 'InstanceSegmentation').
+            split: Optional split name to tag samples (e.g., 'train', 'val').
+                If provided, all samples will be tagged with this name.
             embed: If True, generate embeddings for the newly added samples.
         """
         if isinstance(annotations_json, str):
@@ -330,12 +366,83 @@
 
         images_path = Path(images_path).absolute()
 
-        self.add_samples_from_labelformat(
+        created_sample_ids = add_samples.load_into_dataset_from_labelformat(
+            session=self.session,
+            dataset_id=self.dataset_id,
             input_labels=label_input,
             images_path=images_path,
-            embed=embed,
         )
 
+        # Tag samples with split name if provided
+        if split is not None and created_sample_ids:
+            tag = tag_resolver.get_or_create_sample_tag_by_name(
+                session=self.session,
+                dataset_id=self.dataset_id,
+                tag_name=split,
+            )
+            tag_resolver.add_sample_ids_to_tag_id(
+                session=self.session,
+                tag_id=tag.tag_id,
+                sample_ids=created_sample_ids,
+            )
+
+        if embed:
+            _generate_embeddings(
+                session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
+            )
+
+    def add_samples_from_coco_caption(
+        self,
+        annotations_json: PathLike,
+        images_path: PathLike,
+        split: str | None = None,
+        embed: bool = True,
+    ) -> None:
+        """Load a dataset in COCO caption format and store in DB.
+
+        Args:
+            annotations_json: Path to the COCO caption JSON file.
+            images_path: Path to the folder containing the images.
+            split: Optional split name to tag samples (e.g., 'train', 'val').
+                If provided, all samples will be tagged with this name.
+            embed: If True, generate embeddings for the newly added samples.
+        """
+        if isinstance(annotations_json, str):
+            annotations_json = Path(annotations_json)
+        annotations_json = annotations_json.absolute()
+
+        if not annotations_json.is_file() or annotations_json.suffix != ".json":
+            raise FileNotFoundError(f"COCO caption json file not found: '{annotations_json}'")
+
+        if isinstance(images_path, str):
+            images_path = Path(images_path)
+        images_path = images_path.absolute()
+
+        created_sample_ids = add_samples.load_into_dataset_from_coco_captions(
+            session=self.session,
+            dataset_id=self.dataset_id,
+            annotations_json=annotations_json,
+            images_path=images_path,
+        )
+
+        # Tag samples with split name if provided
+        if split is not None and created_sample_ids:
+            tag = tag_resolver.get_or_create_sample_tag_by_name(
+                session=self.session,
+                dataset_id=self.dataset_id,
+                tag_name=split,
+            )
+            tag_resolver.add_sample_ids_to_tag_id(
+                session=self.session,
+                tag_id=tag.tag_id,
+                sample_ids=created_sample_ids,
+            )
+
+        if embed:
+            _generate_embeddings(
+                session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
+            )
+
     def compute_typicality_metadata(
         self,
         embedding_model_name: str | None = None,
@@ -393,3 +500,23 @@ def _generate_embeddings(session: Session, dataset_id: UUID, sample_ids: list[UU
     # Mark the embedding search feature as enabled.
     if "embeddingSearchEnabled" not in features.lightly_studio_active_features:
         features.lightly_studio_active_features.append("embeddingSearchEnabled")
+
+
+def _resolve_yolo_splits(data_yaml: Path, input_split: str | None) -> list[str]:
+    """Determine which YOLO splits to process for the given config."""
+    if input_split is not None:
+        if input_split not in ALLOWED_YOLO_SPLITS:
+            raise ValueError(
+                f"Split '{input_split}' not found in config file '{data_yaml}'. "
+                f"Allowed splits: {sorted(ALLOWED_YOLO_SPLITS)}"
+            )
+        return [input_split]
+
+    with data_yaml.open() as f:
+        config = yaml.safe_load(f)
+
+    config_keys = config.keys() if isinstance(config, dict) else []
+    splits = [key for key in config_keys if key in ALLOWED_YOLO_SPLITS]
+    if not splits:
+        raise ValueError(f"No splits found in config file '{data_yaml}'")
+    return splits
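Taken together, the Dataset changes add split-aware loading: passing input_split=None to add_samples_from_yolo loads every split declared in data.yaml and tags each sample with its split name, and the new add_samples_from_coco_caption loads images plus captions with an optional split tag. A usage sketch follows; the Dataset.create(...) factory call is an assumption inferred from the creation path visible above, and all file paths are placeholders.

    # Usage sketch for the new loading options; only the two method calls below are
    # taken from this diff, the factory call and paths are assumptions.
    from lightly_studio.core.dataset import Dataset

    dataset = Dataset.create(name="demo_dataset")  # hypothetical factory

    # YOLO: with input_split=None every split present in data.yaml ("train", "val",
    # "test", "minival") is loaded, and samples are tagged with their split name.
    dataset.add_samples_from_yolo(
        data_yaml="path/to/data.yaml",
        input_split=None,
        embed=True,
    )

    # COCO captions: loads samples and their captions, optionally tagging them with a split.
    dataset.add_samples_from_coco_caption(
        annotations_json="path/to/captions_train.json",
        images_path="path/to/images",
        split="train",
        embed=False,
    )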
lightly_studio/dataset/loader.py CHANGED
@@ -258,10 +258,7 @@ class DatasetLoader:
         # Create dataset and annotation task.
         dataset = dataset_resolver.create(
             session=self.session,
-            dataset=DatasetCreate(
-                name=dataset_name,
-                directory=str(img_dir_path),
-            ),
+            dataset=DatasetCreate(name=dataset_name),
         )
 
         self._load_into_dataset(
@@ -296,10 +293,7 @@
         # Create dataset.
         dataset = dataset_resolver.create(
             session=self.session,
-            dataset=DatasetCreate(
-                name=dataset_name,
-                directory=img_dir,
-            ),
+            dataset=DatasetCreate(name=dataset_name),
         )
 
         # Collect image file paths with extension filtering.
lightly_studio/db_manager.py CHANGED
@@ -57,6 +57,11 @@ class DatabaseEngine:
         try:
             yield session
             session.commit()
+
+            # Commit the persistent session to ensure it sees the latest data changes.
+            # This prevents the persistent session from having stale data when it's used
+            # after operations in short-lived sessions have modified the database.
+            self.get_persistent_session().commit()
         except Exception:
             session.rollback()
             raise
@@ -66,7 +71,9 @@
     def get_persistent_session(self) -> Session:
         """Get the persistent database session."""
         if self._persistent_session is None:
-            self._persistent_session = Session(self._engine, close_resets_only=False)
+            self._persistent_session = Session(
+                self._engine, close_resets_only=False, expire_on_commit=True
+            )
         return self._persistent_session
 
 
@@ -78,11 +85,10 @@ def get_engine() -> DatabaseEngine:
     """Get the database engine.
 
     If the engine does not exist yet, it is newly created with the default settings.
-    In that case, a pre-existing database file is deleted.
     """
     global _engine  # noqa: PLW0603
     if _engine is None:
-        _engine = DatabaseEngine(cleanup_existing=True)
+        _engine = DatabaseEngine()
     return _engine
 
 
@@ -94,7 +100,7 @@ def set_engine(engine: DatabaseEngine) -> None:
     _engine = engine
 
 
-def connect(db_file: str | None, cleanup_existing: bool = False) -> None:
+def connect(db_file: str | None = None, cleanup_existing: bool = False) -> None:
     """Set up the database connection.
 
     Helper function to set up the database engine.