lightly-studio 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lightly-studio might be problematic. Click here for more details.
- lightly_studio/api/app.py +2 -0
- lightly_studio/api/routes/api/caption.py +30 -0
- lightly_studio/api/routes/api/embeddings2d.py +36 -4
- lightly_studio/api/routes/api/metadata.py +57 -1
- lightly_studio/core/add_samples.py +138 -0
- lightly_studio/core/dataset.py +143 -16
- lightly_studio/dataset/loader.py +2 -8
- lightly_studio/db_manager.py +10 -4
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.B3oFNb6O.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/2.CkOblLn7.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/Samples.CIbricz7.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.7Ma7YdVg.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/{useFeatureFlags.CV-KWLNP.css → _layout.CefECEWA.css} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/transform.2jKMtOWG.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/-DXuGN29.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{Cs1XmhiF.js → B7302SU7.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BeWf8-vJ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bqz7dyEC.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C1FmrZbK.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{BdfTHw61.js → CSCQddQS.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CZGpyrcA.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CfQ4mGwl.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CiaNZCBa.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cqo0Vpvt.js +417 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cy4fgWTG.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5w4xp5l.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DD63uD-T.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DQ8aZ1o-.js +3 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{keKYsoph.js → DSxvnAMh.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D_JuJOO3.js +20 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D_ynJAfY.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Dafy4oEQ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{BfHVnyNT.js → Dj4O-5se.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DmjAI-UV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Dug7Bq1S.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Dv5BSBQG.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DzBTnFhV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DzX_yyqb.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Frwd2CjB.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H4l0JFh9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H60ATh8g.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{6t3IJ0vQ.js → qIv1kPyv.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/sLqs1uaK.js +20 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/u-it74zV.js +96 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.BPc0HQPq.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.SNvc2nrm.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.5jT7P06o.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.Cdy-7S5q.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.C_uoESTX.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.DcO8wIAc.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{2.C8HLK8mj.js → 2.BIldfkxL.js} +268 -113
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{3.CLvg3QcJ.js → 3.BC9z_TWM.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{4.BQhDtXUI.js → 4.D8X_Ch5n.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.CAXhxJu6.js +39 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{6.uBV1Lhat.js → 6.DRA5Ru_2.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.WVBsruHQ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.BuKUrCEN.js +20 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CUIn1yCR.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -1
- lightly_studio/dist_lightly_studio_view_app/index.html +15 -14
- lightly_studio/examples/example.py +4 -0
- lightly_studio/examples/example_coco.py +4 -0
- lightly_studio/examples/example_coco_caption.py +24 -0
- lightly_studio/examples/example_metadata.py +4 -1
- lightly_studio/examples/example_selection.py +4 -0
- lightly_studio/examples/example_split_work.py +4 -0
- lightly_studio/examples/example_yolo.py +4 -0
- lightly_studio/export/export_dataset.py +11 -3
- lightly_studio/metadata/compute_typicality.py +1 -1
- lightly_studio/models/caption.py +73 -0
- lightly_studio/models/dataset.py +1 -2
- lightly_studio/models/metadata.py +1 -1
- lightly_studio/models/sample.py +2 -2
- lightly_studio/resolvers/caption_resolver.py +80 -0
- lightly_studio/resolvers/dataset_resolver.py +4 -7
- lightly_studio/resolvers/metadata_resolver/__init__.py +2 -2
- lightly_studio/resolvers/metadata_resolver/sample/__init__.py +3 -3
- lightly_studio/resolvers/metadata_resolver/sample/bulk_update_metadata.py +46 -0
- lightly_studio/resolvers/samples_filter.py +18 -10
- lightly_studio/type_definitions.py +2 -0
- {lightly_studio-0.3.3.dist-info → lightly_studio-0.3.4.dist-info}/METADATA +86 -21
- {lightly_studio-0.3.3.dist-info → lightly_studio-0.3.4.dist-info}/RECORD +83 -77
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.CA_CXIBb.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.DS78jgNY.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/index.BVs_sZj9.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/transform.D487hwJk.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/8NsknIT2.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BND_-4Kp.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BjkP1AHA.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BuuNVL9G.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BzKGpnl4.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CCx7Ho51.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CH6P3X75.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CR2upx_Q.js +0 -4
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWPZrTTJ.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CwPowJfP.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxFKfZ9T.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cxevwdid.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D4whDBUi.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6r9vr07.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DA6bFLPR.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DEgUu98i.js +0 -3
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DGTPl6Gk.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DKGxBSlK.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DQXoLcsF.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DQe_kdRt.js +0 -92
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DcY4jgG3.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/RmD8FzRo.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/V-MnMC1X.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.BVr6DYqP.js +0 -2
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.u7zsVvqp.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.Da2agmdd.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B11tVRJV.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.l30Zud4h.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CgKPGcAP.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.-6XqWX5G.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.BXsgoQZh.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.BkbcnUs8.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.Bkrv-Vww.js +0 -1
- lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +0 -48
- {lightly_studio-0.3.3.dist-info → lightly_studio-0.3.4.dist-info}/WHEEL +0 -0
lightly_studio/api/app.py
CHANGED
|
@@ -16,6 +16,7 @@ from lightly_studio.api.routes import healthz, images, webapp
|
|
|
16
16
|
from lightly_studio.api.routes.api import (
|
|
17
17
|
annotation,
|
|
18
18
|
annotation_label,
|
|
19
|
+
caption,
|
|
19
20
|
classifier,
|
|
20
21
|
dataset,
|
|
21
22
|
dataset_tag,
|
|
@@ -89,6 +90,7 @@ api_router.include_router(export.export_router)
|
|
|
89
90
|
api_router.include_router(sample.samples_router)
|
|
90
91
|
api_router.include_router(annotation_label.annotations_label_router)
|
|
91
92
|
api_router.include_router(annotation.annotations_router)
|
|
93
|
+
api_router.include_router(caption.captions_router)
|
|
92
94
|
api_router.include_router(text_embedding.text_embedding_router)
|
|
93
95
|
api_router.include_router(settings.settings_router)
|
|
94
96
|
api_router.include_router(classifier.classifier_router)
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""API routes for dataset captions."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from uuid import UUID
|
|
6
|
+
|
|
7
|
+
from fastapi import APIRouter, Depends, Path
|
|
8
|
+
from typing_extensions import Annotated
|
|
9
|
+
|
|
10
|
+
from lightly_studio.api.routes.api.validators import Paginated, PaginatedWithCursor
|
|
11
|
+
from lightly_studio.db_manager import SessionDep
|
|
12
|
+
from lightly_studio.models.caption import CaptionsListView
|
|
13
|
+
from lightly_studio.resolvers import caption_resolver
|
|
14
|
+
from lightly_studio.resolvers.caption_resolver import GetAllCaptionsResult
|
|
15
|
+
|
|
16
|
+
captions_router = APIRouter(prefix="/datasets/{dataset_id}", tags=["captions"])
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@captions_router.get("/captions", response_model=CaptionsListView)
|
|
20
|
+
def read_captions(
|
|
21
|
+
dataset_id: Annotated[UUID, Path(title="Dataset Id")],
|
|
22
|
+
session: SessionDep,
|
|
23
|
+
pagination: Annotated[PaginatedWithCursor, Depends()],
|
|
24
|
+
) -> GetAllCaptionsResult:
|
|
25
|
+
"""Retrieve captions for a dataset."""
|
|
26
|
+
return caption_resolver.get_all(
|
|
27
|
+
session=session,
|
|
28
|
+
dataset_id=dataset_id,
|
|
29
|
+
pagination=Paginated(offset=pagination.offset, limit=pagination.limit),
|
|
30
|
+
)
|
|
@@ -3,25 +3,40 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import io
|
|
6
|
+
from uuid import UUID
|
|
6
7
|
|
|
7
8
|
import numpy as np
|
|
8
9
|
import pyarrow as pa
|
|
9
10
|
from fastapi import APIRouter, HTTPException, Response
|
|
10
11
|
from numpy.typing import NDArray
|
|
11
12
|
from pyarrow import ipc
|
|
13
|
+
from pydantic import BaseModel, Field
|
|
12
14
|
from sklearn.manifold import TSNE
|
|
13
15
|
from sqlmodel import select
|
|
14
16
|
|
|
15
17
|
from lightly_studio.db_manager import SessionDep
|
|
16
18
|
from lightly_studio.models.dataset import DatasetTable
|
|
17
19
|
from lightly_studio.models.embedding_model import EmbeddingModelTable
|
|
18
|
-
from lightly_studio.resolvers import sample_embedding_resolver
|
|
20
|
+
from lightly_studio.resolvers import sample_embedding_resolver, sample_resolver
|
|
21
|
+
from lightly_studio.resolvers.samples_filter import SampleFilter
|
|
19
22
|
|
|
20
23
|
embeddings2d_router = APIRouter()
|
|
21
24
|
|
|
22
25
|
|
|
23
|
-
|
|
24
|
-
|
|
26
|
+
class GetEmbeddings2DRequest(BaseModel):
|
|
27
|
+
"""Request body for retrieving 2D embeddings."""
|
|
28
|
+
|
|
29
|
+
filters: SampleFilter | None = Field(
|
|
30
|
+
None,
|
|
31
|
+
description="Filter parameters identifying matching samples",
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@embeddings2d_router.post("/embeddings2d/tsne")
|
|
36
|
+
def get_embeddings2d__tsne(
|
|
37
|
+
session: SessionDep,
|
|
38
|
+
body: GetEmbeddings2DRequest | None = None,
|
|
39
|
+
) -> Response:
|
|
25
40
|
"""Return 2D embeddings serialized as an Arrow stream."""
|
|
26
41
|
# TODO(Malte, 09/2025): Support choosing the dataset via API parameter.
|
|
27
42
|
dataset = session.exec(select(DatasetTable).limit(1)).first()
|
|
@@ -37,7 +52,6 @@ def get_embeddings2d__tsne(session: SessionDep) -> Response:
|
|
|
37
52
|
if embedding_model is None:
|
|
38
53
|
raise HTTPException(status_code=404, detail="No embedding model configured.")
|
|
39
54
|
|
|
40
|
-
# TODO(Malte, 09/2025): Support choosing a subset of samples via API parameter.
|
|
41
55
|
embeddings = sample_embedding_resolver.get_all_by_dataset_id(
|
|
42
56
|
session=session,
|
|
43
57
|
dataset_id=dataset.dataset_id,
|
|
@@ -49,6 +63,22 @@ def get_embeddings2d__tsne(session: SessionDep) -> Response:
|
|
|
49
63
|
x = embedding_values_tsne[:, 0]
|
|
50
64
|
y = embedding_values_tsne[:, 1]
|
|
51
65
|
|
|
66
|
+
matching_sample_ids: set[UUID] | None = None
|
|
67
|
+
filters = body.filters if body else None
|
|
68
|
+
if filters:
|
|
69
|
+
matching_samples_result = sample_resolver.get_all_by_dataset_id(
|
|
70
|
+
session=session,
|
|
71
|
+
dataset_id=dataset.dataset_id,
|
|
72
|
+
filters=filters,
|
|
73
|
+
)
|
|
74
|
+
matching_sample_ids = {sample.sample_id for sample in matching_samples_result.samples}
|
|
75
|
+
|
|
76
|
+
sample_ids = [embedding.sample_id for embedding in embeddings]
|
|
77
|
+
if matching_sample_ids is None:
|
|
78
|
+
fulfils_filter = [1] * len(sample_ids)
|
|
79
|
+
else:
|
|
80
|
+
fulfils_filter = [1 if sample_id in matching_sample_ids else 0 for sample_id in sample_ids]
|
|
81
|
+
|
|
52
82
|
# TODO(Malte, 09/2025): Save the 2D-embeddings in the database to avoid recomputing
|
|
53
83
|
# them on every request.
|
|
54
84
|
|
|
@@ -57,6 +87,8 @@ def get_embeddings2d__tsne(session: SessionDep) -> Response:
|
|
|
57
87
|
{
|
|
58
88
|
"x": pa.array(x, type=pa.float32()),
|
|
59
89
|
"y": pa.array(y, type=pa.float32()),
|
|
90
|
+
"fulfils_filter": pa.array(fulfils_filter, type=pa.uint8()),
|
|
91
|
+
"sample_id": pa.array([str(sample_id) for sample_id in sample_ids], type=pa.string()),
|
|
60
92
|
}
|
|
61
93
|
)
|
|
62
94
|
|
|
@@ -5,11 +5,16 @@ from __future__ import annotations
|
|
|
5
5
|
from typing import List
|
|
6
6
|
from uuid import UUID
|
|
7
7
|
|
|
8
|
-
from fastapi import APIRouter, Path
|
|
8
|
+
from fastapi import APIRouter, Depends, Path
|
|
9
|
+
from pydantic import BaseModel, Field
|
|
9
10
|
from typing_extensions import Annotated
|
|
10
11
|
|
|
12
|
+
from lightly_studio.api.routes.api.dataset import get_and_validate_dataset_id
|
|
11
13
|
from lightly_studio.db_manager import SessionDep
|
|
14
|
+
from lightly_studio.metadata import compute_typicality
|
|
15
|
+
from lightly_studio.models.dataset import DatasetTable
|
|
12
16
|
from lightly_studio.models.metadata import MetadataInfoView
|
|
17
|
+
from lightly_studio.resolvers import embedding_model_resolver
|
|
13
18
|
from lightly_studio.resolvers.metadata_resolver.sample.get_metadata_info import (
|
|
14
19
|
get_all_metadata_keys_and_schema,
|
|
15
20
|
)
|
|
@@ -33,3 +38,54 @@ def get_metadata_info(
|
|
|
33
38
|
for numerical metadata types.
|
|
34
39
|
"""
|
|
35
40
|
return get_all_metadata_keys_and_schema(session=session, dataset_id=dataset_id)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ComputeTypicalityRequest(BaseModel):
|
|
44
|
+
"""Request model for computing typicality metadata."""
|
|
45
|
+
|
|
46
|
+
embedding_model_name: str | None = Field(
|
|
47
|
+
default=None,
|
|
48
|
+
description="Embedding model name (uses default if not specified)",
|
|
49
|
+
)
|
|
50
|
+
metadata_name: str = Field(
|
|
51
|
+
default="typicality",
|
|
52
|
+
description="Metadata field name (defaults to 'typicality')",
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@metadata_router.post(
|
|
57
|
+
"/metadata/typicality",
|
|
58
|
+
status_code=204,
|
|
59
|
+
response_model=None,
|
|
60
|
+
)
|
|
61
|
+
def compute_typicality_metadata(
|
|
62
|
+
session: SessionDep,
|
|
63
|
+
dataset: Annotated[
|
|
64
|
+
DatasetTable,
|
|
65
|
+
Depends(get_and_validate_dataset_id),
|
|
66
|
+
],
|
|
67
|
+
request: ComputeTypicalityRequest,
|
|
68
|
+
) -> None:
|
|
69
|
+
"""Compute typicality metadata for a dataset.
|
|
70
|
+
|
|
71
|
+
Args:
|
|
72
|
+
session: The database session.
|
|
73
|
+
dataset: The dataset to compute typicality for.
|
|
74
|
+
request: Request parameters including optional embedding model name
|
|
75
|
+
and metadata field name.
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
None (204 No Content on success).
|
|
79
|
+
"""
|
|
80
|
+
embedding_model = embedding_model_resolver.get_by_name(
|
|
81
|
+
session=session,
|
|
82
|
+
dataset_id=dataset.dataset_id,
|
|
83
|
+
embedding_model_name=request.embedding_model_name,
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
compute_typicality.compute_typicality_metadata(
|
|
87
|
+
session=session,
|
|
88
|
+
dataset_id=dataset.dataset_id,
|
|
89
|
+
embedding_model_id=embedding_model.embedding_model_id,
|
|
90
|
+
metadata_name=request.metadata_name,
|
|
91
|
+
)
|
|
@@ -2,6 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import json
|
|
6
|
+
from collections import defaultdict
|
|
5
7
|
from dataclasses import dataclass, field
|
|
6
8
|
from pathlib import Path
|
|
7
9
|
from typing import Iterable
|
|
@@ -26,10 +28,12 @@ from tqdm import tqdm
|
|
|
26
28
|
|
|
27
29
|
from lightly_studio.models.annotation.annotation_base import AnnotationCreate
|
|
28
30
|
from lightly_studio.models.annotation_label import AnnotationLabelCreate
|
|
31
|
+
from lightly_studio.models.caption import CaptionCreate
|
|
29
32
|
from lightly_studio.models.sample import SampleCreate, SampleTable
|
|
30
33
|
from lightly_studio.resolvers import (
|
|
31
34
|
annotation_label_resolver,
|
|
32
35
|
annotation_resolver,
|
|
36
|
+
caption_resolver,
|
|
33
37
|
sample_resolver,
|
|
34
38
|
)
|
|
35
39
|
|
|
@@ -218,6 +222,111 @@ def load_into_dataset_from_labelformat(
|
|
|
218
222
|
return created_sample_ids
|
|
219
223
|
|
|
220
224
|
|
|
225
|
+
def load_into_dataset_from_coco_captions(
|
|
226
|
+
session: Session,
|
|
227
|
+
dataset_id: UUID,
|
|
228
|
+
annotations_json: Path,
|
|
229
|
+
images_path: Path,
|
|
230
|
+
) -> list[UUID]:
|
|
231
|
+
"""Load samples and captions from a COCO captions file into the dataset.
|
|
232
|
+
|
|
233
|
+
Args:
|
|
234
|
+
session: Database session used for resolver operations.
|
|
235
|
+
dataset_id: Identifier of the dataset that receives the samples.
|
|
236
|
+
annotations_json: Path to the COCO captions annotations file.
|
|
237
|
+
images_path: Directory containing the referenced images.
|
|
238
|
+
|
|
239
|
+
Returns:
|
|
240
|
+
The list of newly created sample identifiers.
|
|
241
|
+
"""
|
|
242
|
+
with fsspec.open(str(annotations_json), "r") as file:
|
|
243
|
+
coco_payload = json.load(file)
|
|
244
|
+
|
|
245
|
+
images: list[dict[str, object]] = coco_payload.get("images", [])
|
|
246
|
+
annotations: list[dict[str, object]] = coco_payload.get("annotations", [])
|
|
247
|
+
|
|
248
|
+
captions_by_image_id: dict[int, list[str]] = defaultdict(list)
|
|
249
|
+
for annotation in annotations:
|
|
250
|
+
image_id = annotation["image_id"]
|
|
251
|
+
caption = annotation["caption"]
|
|
252
|
+
if not isinstance(image_id, int):
|
|
253
|
+
continue
|
|
254
|
+
if not isinstance(caption, str):
|
|
255
|
+
continue
|
|
256
|
+
caption_text = caption.strip()
|
|
257
|
+
if not caption_text:
|
|
258
|
+
continue
|
|
259
|
+
captions_by_image_id[image_id].append(caption_text)
|
|
260
|
+
|
|
261
|
+
logging_context = _LoadingLoggingContext(
|
|
262
|
+
n_samples_to_be_inserted=len(images),
|
|
263
|
+
n_samples_before_loading=sample_resolver.count_by_dataset_id(
|
|
264
|
+
session=session, dataset_id=dataset_id
|
|
265
|
+
),
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
captions_to_create: list[CaptionCreate] = []
|
|
269
|
+
samples_to_create: list[SampleCreate] = []
|
|
270
|
+
created_sample_ids: list[UUID] = []
|
|
271
|
+
image_path_to_captions: dict[str, list[str]] = {}
|
|
272
|
+
|
|
273
|
+
for image_info in tqdm(images, desc="Processing images", unit=" images"):
|
|
274
|
+
if isinstance(image_info["id"], int):
|
|
275
|
+
image_id_raw = image_info["id"]
|
|
276
|
+
else:
|
|
277
|
+
continue
|
|
278
|
+
file_name_raw = str(image_info["file_name"])
|
|
279
|
+
|
|
280
|
+
width = image_info["width"] if isinstance(image_info["width"], int) else 0
|
|
281
|
+
height = image_info["height"] if isinstance(image_info["height"], int) else 0
|
|
282
|
+
sample = SampleCreate(
|
|
283
|
+
file_name=file_name_raw,
|
|
284
|
+
file_path_abs=str(images_path / file_name_raw),
|
|
285
|
+
width=width,
|
|
286
|
+
height=height,
|
|
287
|
+
dataset_id=dataset_id,
|
|
288
|
+
)
|
|
289
|
+
samples_to_create.append(sample)
|
|
290
|
+
image_path_to_captions[sample.file_path_abs] = captions_by_image_id.get(image_id_raw, [])
|
|
291
|
+
|
|
292
|
+
if len(samples_to_create) >= SAMPLE_BATCH_SIZE:
|
|
293
|
+
created_samples_batch, paths_not_inserted = _create_batch_samples(
|
|
294
|
+
session=session, samples=samples_to_create
|
|
295
|
+
)
|
|
296
|
+
created_sample_ids.extend(s.sample_id for s in created_samples_batch)
|
|
297
|
+
logging_context.update_example_paths(paths_not_inserted)
|
|
298
|
+
_process_batch_captions(
|
|
299
|
+
session=session,
|
|
300
|
+
dataset_id=dataset_id,
|
|
301
|
+
stored_samples=created_samples_batch,
|
|
302
|
+
image_path_to_captions=image_path_to_captions,
|
|
303
|
+
captions_to_create=captions_to_create,
|
|
304
|
+
)
|
|
305
|
+
samples_to_create.clear()
|
|
306
|
+
image_path_to_captions.clear()
|
|
307
|
+
|
|
308
|
+
if samples_to_create:
|
|
309
|
+
created_samples_batch, paths_not_inserted = _create_batch_samples(
|
|
310
|
+
session=session, samples=samples_to_create
|
|
311
|
+
)
|
|
312
|
+
created_sample_ids.extend(s.sample_id for s in created_samples_batch)
|
|
313
|
+
logging_context.update_example_paths(paths_not_inserted)
|
|
314
|
+
_process_batch_captions(
|
|
315
|
+
session=session,
|
|
316
|
+
dataset_id=dataset_id,
|
|
317
|
+
stored_samples=created_samples_batch,
|
|
318
|
+
image_path_to_captions=image_path_to_captions,
|
|
319
|
+
captions_to_create=captions_to_create,
|
|
320
|
+
)
|
|
321
|
+
|
|
322
|
+
if captions_to_create:
|
|
323
|
+
caption_resolver.create_many(session=session, captions=captions_to_create)
|
|
324
|
+
|
|
325
|
+
_log_loading_results(session=session, dataset_id=dataset_id, logging_context=logging_context)
|
|
326
|
+
|
|
327
|
+
return created_sample_ids
|
|
328
|
+
|
|
329
|
+
|
|
221
330
|
def _log_loading_results(
|
|
222
331
|
session: Session, dataset_id: UUID, logging_context: _LoadingLoggingContext
|
|
223
332
|
) -> None:
|
|
@@ -372,3 +481,32 @@ def _process_batch_annotations( # noqa: PLR0913
|
|
|
372
481
|
if len(annotations_to_create) >= ANNOTATION_BATCH_SIZE:
|
|
373
482
|
annotation_resolver.create_many(session=session, annotations=annotations_to_create)
|
|
374
483
|
annotations_to_create.clear()
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
def _process_batch_captions(
|
|
487
|
+
session: Session,
|
|
488
|
+
dataset_id: UUID,
|
|
489
|
+
stored_samples: list[SampleTable],
|
|
490
|
+
image_path_to_captions: dict[str, list[str]],
|
|
491
|
+
captions_to_create: list[CaptionCreate],
|
|
492
|
+
) -> None:
|
|
493
|
+
"""Process captions for a batch of samples."""
|
|
494
|
+
if not stored_samples:
|
|
495
|
+
return
|
|
496
|
+
|
|
497
|
+
for stored_sample in stored_samples:
|
|
498
|
+
captions = image_path_to_captions[stored_sample.file_path_abs]
|
|
499
|
+
if not captions:
|
|
500
|
+
continue
|
|
501
|
+
|
|
502
|
+
for caption_text in captions:
|
|
503
|
+
caption = CaptionCreate(
|
|
504
|
+
dataset_id=dataset_id,
|
|
505
|
+
sample_id=stored_sample.sample_id,
|
|
506
|
+
text=caption_text,
|
|
507
|
+
)
|
|
508
|
+
captions_to_create.append(caption)
|
|
509
|
+
|
|
510
|
+
if len(captions_to_create) >= ANNOTATION_BATCH_SIZE:
|
|
511
|
+
caption_resolver.create_many(session=session, captions=captions_to_create)
|
|
512
|
+
captions_to_create.clear()
|
lightly_studio/core/dataset.py
CHANGED
|
@@ -6,6 +6,7 @@ from pathlib import Path
|
|
|
6
6
|
from typing import Iterable, Iterator
|
|
7
7
|
from uuid import UUID
|
|
8
8
|
|
|
9
|
+
import yaml
|
|
9
10
|
from labelformat.formats import (
|
|
10
11
|
COCOInstanceSegmentationInput,
|
|
11
12
|
COCOObjectDetectionInput,
|
|
@@ -38,11 +39,13 @@ from lightly_studio.resolvers import (
|
|
|
38
39
|
dataset_resolver,
|
|
39
40
|
embedding_model_resolver,
|
|
40
41
|
sample_resolver,
|
|
42
|
+
tag_resolver,
|
|
41
43
|
)
|
|
42
44
|
from lightly_studio.type_definitions import PathLike
|
|
43
45
|
|
|
44
46
|
# Constants
|
|
45
47
|
DEFAULT_DATASET_NAME = "default_dataset"
|
|
48
|
+
ALLOWED_YOLO_SPLITS = {"train", "val", "test", "minival"}
|
|
46
49
|
|
|
47
50
|
_SliceType = slice # to avoid shadowing built-in slice in type annotations
|
|
48
51
|
|
|
@@ -68,7 +71,7 @@ class Dataset:
|
|
|
68
71
|
|
|
69
72
|
dataset = dataset_resolver.create(
|
|
70
73
|
session=db_manager.persistent_session(),
|
|
71
|
-
dataset=DatasetCreate(name=name
|
|
74
|
+
dataset=DatasetCreate(name=name),
|
|
72
75
|
)
|
|
73
76
|
return Dataset(dataset=dataset)
|
|
74
77
|
|
|
@@ -262,14 +265,15 @@ class Dataset:
|
|
|
262
265
|
def add_samples_from_yolo(
|
|
263
266
|
self,
|
|
264
267
|
data_yaml: PathLike,
|
|
265
|
-
input_split: str =
|
|
268
|
+
input_split: str | None = None,
|
|
266
269
|
embed: bool = True,
|
|
267
270
|
) -> None:
|
|
268
271
|
"""Load a dataset in YOLO format and store in DB.
|
|
269
272
|
|
|
270
273
|
Args:
|
|
271
274
|
data_yaml: Path to the YOLO data.yaml file.
|
|
272
|
-
input_split: The split to load (e.g., 'train', 'val').
|
|
275
|
+
input_split: The split to load (e.g., 'train', 'val', 'test').
|
|
276
|
+
If None, all available splits will be loaded and assigned a corresponding tag.
|
|
273
277
|
embed: If True, generate embeddings for the newly added samples.
|
|
274
278
|
"""
|
|
275
279
|
if isinstance(data_yaml, str):
|
|
@@ -279,24 +283,54 @@ class Dataset:
|
|
|
279
283
|
if not data_yaml.is_file() or data_yaml.suffix != ".yaml":
|
|
280
284
|
raise FileNotFoundError(f"YOLO data yaml file not found: '{data_yaml}'")
|
|
281
285
|
|
|
282
|
-
#
|
|
283
|
-
|
|
284
|
-
input_file=data_yaml,
|
|
285
|
-
input_split=input_split,
|
|
286
|
-
)
|
|
287
|
-
images_path = label_input._images_dir() # noqa: SLF001
|
|
286
|
+
# Determine which splits to process
|
|
287
|
+
splits_to_process = _resolve_yolo_splits(data_yaml=data_yaml, input_split=input_split)
|
|
288
288
|
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
289
|
+
all_created_sample_ids = []
|
|
290
|
+
|
|
291
|
+
# Process each split
|
|
292
|
+
for split in splits_to_process:
|
|
293
|
+
# Load the dataset using labelformat.
|
|
294
|
+
label_input = YOLOv8ObjectDetectionInput(
|
|
295
|
+
input_file=data_yaml,
|
|
296
|
+
input_split=split,
|
|
297
|
+
)
|
|
298
|
+
images_path = label_input._images_dir() # noqa: SLF001
|
|
299
|
+
|
|
300
|
+
created_sample_ids = add_samples.load_into_dataset_from_labelformat(
|
|
301
|
+
session=self.session,
|
|
302
|
+
dataset_id=self.dataset_id,
|
|
303
|
+
input_labels=label_input,
|
|
304
|
+
images_path=images_path,
|
|
305
|
+
)
|
|
306
|
+
|
|
307
|
+
# Tag samples with split name
|
|
308
|
+
if created_sample_ids:
|
|
309
|
+
tag = tag_resolver.get_or_create_sample_tag_by_name(
|
|
310
|
+
session=self.session,
|
|
311
|
+
dataset_id=self.dataset_id,
|
|
312
|
+
tag_name=split,
|
|
313
|
+
)
|
|
314
|
+
tag_resolver.add_sample_ids_to_tag_id(
|
|
315
|
+
session=self.session,
|
|
316
|
+
tag_id=tag.tag_id,
|
|
317
|
+
sample_ids=created_sample_ids,
|
|
318
|
+
)
|
|
319
|
+
|
|
320
|
+
all_created_sample_ids.extend(created_sample_ids)
|
|
321
|
+
|
|
322
|
+
# Generate embeddings for all samples at once
|
|
323
|
+
if embed:
|
|
324
|
+
_generate_embeddings(
|
|
325
|
+
session=self.session, dataset_id=self.dataset_id, sample_ids=all_created_sample_ids
|
|
326
|
+
)
|
|
294
327
|
|
|
295
328
|
def add_samples_from_coco(
|
|
296
329
|
self,
|
|
297
330
|
annotations_json: PathLike,
|
|
298
331
|
images_path: PathLike,
|
|
299
332
|
annotation_type: AnnotationType = AnnotationType.OBJECT_DETECTION,
|
|
333
|
+
split: str | None = None,
|
|
300
334
|
embed: bool = True,
|
|
301
335
|
) -> None:
|
|
302
336
|
"""Load a dataset in COCO Object Detection format and store in DB.
|
|
@@ -306,6 +340,8 @@ class Dataset:
|
|
|
306
340
|
images_path: Path to the folder containing the images.
|
|
307
341
|
annotation_type: The type of annotation to be loaded (e.g., 'ObjectDetection',
|
|
308
342
|
'InstanceSegmentation').
|
|
343
|
+
split: Optional split name to tag samples (e.g., 'train', 'val').
|
|
344
|
+
If provided, all samples will be tagged with this name.
|
|
309
345
|
embed: If True, generate embeddings for the newly added samples.
|
|
310
346
|
"""
|
|
311
347
|
if isinstance(annotations_json, str):
|
|
@@ -330,12 +366,83 @@ class Dataset:
|
|
|
330
366
|
|
|
331
367
|
images_path = Path(images_path).absolute()
|
|
332
368
|
|
|
333
|
-
|
|
369
|
+
created_sample_ids = add_samples.load_into_dataset_from_labelformat(
|
|
370
|
+
session=self.session,
|
|
371
|
+
dataset_id=self.dataset_id,
|
|
334
372
|
input_labels=label_input,
|
|
335
373
|
images_path=images_path,
|
|
336
|
-
embed=embed,
|
|
337
374
|
)
|
|
338
375
|
|
|
376
|
+
# Tag samples with split name if provided
|
|
377
|
+
if split is not None and created_sample_ids:
|
|
378
|
+
tag = tag_resolver.get_or_create_sample_tag_by_name(
|
|
379
|
+
session=self.session,
|
|
380
|
+
dataset_id=self.dataset_id,
|
|
381
|
+
tag_name=split,
|
|
382
|
+
)
|
|
383
|
+
tag_resolver.add_sample_ids_to_tag_id(
|
|
384
|
+
session=self.session,
|
|
385
|
+
tag_id=tag.tag_id,
|
|
386
|
+
sample_ids=created_sample_ids,
|
|
387
|
+
)
|
|
388
|
+
|
|
389
|
+
if embed:
|
|
390
|
+
_generate_embeddings(
|
|
391
|
+
session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
|
|
392
|
+
)
|
|
393
|
+
|
|
394
|
+
def add_samples_from_coco_caption(
    self,
    annotations_json: PathLike,
    images_path: PathLike,
    split: str | None = None,
    embed: bool = True,
) -> None:
    """Load a dataset in COCO caption format and store in DB.

    Args:
        annotations_json: Path to the COCO caption JSON file.
        images_path: Path to the folder containing the images.
        split: Optional split name to tag samples (e.g., 'train', 'val').
            If provided, all samples will be tagged with this name.
        embed: If True, generate embeddings for the newly added samples.

    Raises:
        FileNotFoundError: If `annotations_json` does not point to an
            existing file with a `.json` suffix.
    """
    # Path() accepts both str and os.PathLike inputs, so no separate
    # isinstance(str) branch is needed before normalizing.
    annotations_json = Path(annotations_json).absolute()

    if not annotations_json.is_file() or annotations_json.suffix != ".json":
        raise FileNotFoundError(f"COCO caption json file not found: '{annotations_json}'")

    images_path = Path(images_path).absolute()

    created_sample_ids = add_samples.load_into_dataset_from_coco_captions(
        session=self.session,
        dataset_id=self.dataset_id,
        annotations_json=annotations_json,
        images_path=images_path,
    )

    # Tag samples with split name if provided.
    if split is not None and created_sample_ids:
        tag = tag_resolver.get_or_create_sample_tag_by_name(
            session=self.session,
            dataset_id=self.dataset_id,
            tag_name=split,
        )
        tag_resolver.add_sample_ids_to_tag_id(
            session=self.session,
            tag_id=tag.tag_id,
            sample_ids=created_sample_ids,
        )

    if embed:
        _generate_embeddings(
            session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
        )
|
|
445
|
+
|
|
339
446
|
def compute_typicality_metadata(
|
|
340
447
|
self,
|
|
341
448
|
embedding_model_name: str | None = None,
|
|
@@ -393,3 +500,23 @@ def _generate_embeddings(session: Session, dataset_id: UUID, sample_ids: list[UU
|
|
|
393
500
|
# Mark the embedding search feature as enabled.
|
|
394
501
|
if "embeddingSearchEnabled" not in features.lightly_studio_active_features:
|
|
395
502
|
features.lightly_studio_active_features.append("embeddingSearchEnabled")
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
def _resolve_yolo_splits(data_yaml: Path, input_split: str | None) -> list[str]:
    """Determine which YOLO splits to process for the given config.

    Args:
        data_yaml: Path to the YOLO data YAML config file.
        input_split: Optional explicit split name requested by the caller.

    Returns:
        List of split names to process. If `input_split` is given it is the
        sole entry; otherwise every allowed split key present in the config
        is returned.

    Raises:
        ValueError: If `input_split` is not an allowed split name, or if the
            config file contains no recognized split keys.
    """
    if input_split is not None:
        if input_split not in ALLOWED_YOLO_SPLITS:
            # This path validates against ALLOWED_YOLO_SPLITS without ever
            # reading the config file, so the error message must not claim
            # the split was "not found in the config file".
            raise ValueError(
                f"Split '{input_split}' is not a valid YOLO split. "
                f"Allowed splits: {sorted(ALLOWED_YOLO_SPLITS)}"
            )
        return [input_split]

    with data_yaml.open() as f:
        config = yaml.safe_load(f)

    # safe_load may return a non-dict (e.g. None for an empty file); treat
    # that as "no splits present" rather than crashing on .keys().
    config_keys = config.keys() if isinstance(config, dict) else []
    splits = [key for key in config_keys if key in ALLOWED_YOLO_SPLITS]
    if not splits:
        raise ValueError(f"No splits found in config file '{data_yaml}'")
    return splits
|
lightly_studio/dataset/loader.py
CHANGED
|
@@ -258,10 +258,7 @@ class DatasetLoader:
|
|
|
258
258
|
# Create dataset and annotation task.
|
|
259
259
|
dataset = dataset_resolver.create(
|
|
260
260
|
session=self.session,
|
|
261
|
-
dataset=DatasetCreate(
|
|
262
|
-
name=dataset_name,
|
|
263
|
-
directory=str(img_dir_path),
|
|
264
|
-
),
|
|
261
|
+
dataset=DatasetCreate(name=dataset_name),
|
|
265
262
|
)
|
|
266
263
|
|
|
267
264
|
self._load_into_dataset(
|
|
@@ -296,10 +293,7 @@ class DatasetLoader:
|
|
|
296
293
|
# Create dataset.
|
|
297
294
|
dataset = dataset_resolver.create(
|
|
298
295
|
session=self.session,
|
|
299
|
-
dataset=DatasetCreate(
|
|
300
|
-
name=dataset_name,
|
|
301
|
-
directory=img_dir,
|
|
302
|
-
),
|
|
296
|
+
dataset=DatasetCreate(name=dataset_name),
|
|
303
297
|
)
|
|
304
298
|
|
|
305
299
|
# Collect image file paths with extension filtering.
|
lightly_studio/db_manager.py
CHANGED
|
@@ -57,6 +57,11 @@ class DatabaseEngine:
|
|
|
57
57
|
try:
|
|
58
58
|
yield session
|
|
59
59
|
session.commit()
|
|
60
|
+
|
|
61
|
+
# Commit the persistent session to ensure it sees the latest data changes.
|
|
62
|
+
# This prevents the persistent session from having stale data when it's used
|
|
63
|
+
# after operations in short-lived sessions have modified the database.
|
|
64
|
+
self.get_persistent_session().commit()
|
|
60
65
|
except Exception:
|
|
61
66
|
session.rollback()
|
|
62
67
|
raise
|
|
@@ -66,7 +71,9 @@ class DatabaseEngine:
|
|
|
66
71
|
def get_persistent_session(self) -> Session:
|
|
67
72
|
"""Get the persistent database session."""
|
|
68
73
|
if self._persistent_session is None:
|
|
69
|
-
self._persistent_session = Session(
|
|
74
|
+
self._persistent_session = Session(
|
|
75
|
+
self._engine, close_resets_only=False, expire_on_commit=True
|
|
76
|
+
)
|
|
70
77
|
return self._persistent_session
|
|
71
78
|
|
|
72
79
|
|
|
@@ -78,11 +85,10 @@ def get_engine() -> DatabaseEngine:
|
|
|
78
85
|
"""Get the database engine.
|
|
79
86
|
|
|
80
87
|
If the engine does not exist yet, it is newly created with the default settings.
|
|
81
|
-
In that case, a pre-existing database file is deleted.
|
|
82
88
|
"""
|
|
83
89
|
global _engine # noqa: PLW0603
|
|
84
90
|
if _engine is None:
|
|
85
|
-
_engine = DatabaseEngine(
|
|
91
|
+
_engine = DatabaseEngine()
|
|
86
92
|
return _engine
|
|
87
93
|
|
|
88
94
|
|
|
@@ -94,7 +100,7 @@ def set_engine(engine: DatabaseEngine) -> None:
|
|
|
94
100
|
_engine = engine
|
|
95
101
|
|
|
96
102
|
|
|
97
|
-
def connect(db_file: str | None, cleanup_existing: bool = False) -> None:
|
|
103
|
+
def connect(db_file: str | None = None, cleanup_existing: bool = False) -> None:
|
|
98
104
|
"""Set up the database connection.
|
|
99
105
|
|
|
100
106
|
Helper function to set up the database engine.
|