lamindb 0.71.2__py3-none-any.whl → 0.72.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lamindb/__init__.py +2 -2
- lamindb/_annotate.py +6 -10
- lamindb/_artifact.py +24 -10
- lamindb/_can_validate.py +9 -3
- lamindb/_collection.py +7 -7
- lamindb/_feature.py +53 -45
- lamindb/_feature_set.py +37 -74
- lamindb/_from_values.py +27 -8
- lamindb/_query_manager.py +6 -1
- lamindb/_registry.py +60 -100
- lamindb/_run.py +0 -2
- lamindb/_save.py +28 -11
- lamindb/core/__init__.py +4 -0
- lamindb/core/_data.py +56 -30
- lamindb/core/_feature_manager.py +159 -64
- lamindb/core/_label_manager.py +53 -38
- lamindb/core/_run_context.py +24 -1
- lamindb/core/datasets/_core.py +10 -18
- lamindb/core/schema.py +53 -0
- {lamindb-0.71.2.dist-info → lamindb-0.72.0.dist-info}/METADATA +7 -6
- {lamindb-0.71.2.dist-info → lamindb-0.72.0.dist-info}/RECORD +23 -22
- {lamindb-0.71.2.dist-info → lamindb-0.72.0.dist-info}/LICENSE +0 -0
- {lamindb-0.71.2.dist-info → lamindb-0.72.0.dist-info}/WHEEL +0 -0
lamindb/_from_values.py
CHANGED
@@ -185,8 +185,15 @@ def create_records_from_public(
|
|
185
185
|
|
186
186
|
# create the corresponding bionty object from model
|
187
187
|
try:
|
188
|
+
# TODO: more generic
|
189
|
+
organism = kwargs.get("organism")
|
190
|
+
if field.field.name == "ensembl_gene_id":
|
191
|
+
if iterable_idx[0].startswith("ENSG"):
|
192
|
+
organism = "human"
|
193
|
+
elif iterable_idx[0].startswith("ENSMUSG"):
|
194
|
+
organism = "mouse"
|
188
195
|
public_ontology = model.public(
|
189
|
-
organism=
|
196
|
+
organism=organism, public_source=kwargs.get("public_source")
|
190
197
|
)
|
191
198
|
except Exception:
|
192
199
|
# for custom records that are not created from public sources
|
@@ -223,8 +230,15 @@ def create_records_from_public(
|
|
223
230
|
bionty_kwargs, multi_msg = _bulk_create_dicts_from_df(
|
224
231
|
keys=mapped_values, column_name=field.field.name, df=bionty_df
|
225
232
|
)
|
233
|
+
organism_kwargs = {}
|
234
|
+
if "organism" not in kwargs:
|
235
|
+
organism_record = _get_organism_record(
|
236
|
+
field, public_ontology.organism, force=True
|
237
|
+
)
|
238
|
+
if organism_record is not None:
|
239
|
+
organism_kwargs["organism"] = organism_record
|
226
240
|
for bk in bionty_kwargs:
|
227
|
-
records.append(model(**bk, **kwargs))
|
241
|
+
records.append(model(**bk, **kwargs, **organism_kwargs))
|
228
242
|
|
229
243
|
# number of records that matches field (not synonyms)
|
230
244
|
validated = result.validated
|
@@ -260,10 +274,11 @@ def index_iterable(iterable: Iterable) -> pd.Index:
|
|
260
274
|
return idx[(idx != "") & (~idx.isnull())]
|
261
275
|
|
262
276
|
|
263
|
-
def _print_values(names:
|
264
|
-
names =
|
265
|
-
|
266
|
-
|
277
|
+
def _print_values(names: Iterable, n: int = 20) -> str:
|
278
|
+
names = (name for name in names if name != "None")
|
279
|
+
unique_names = list(dict.fromkeys(names))[:n]
|
280
|
+
print_values = ", ".join(f"'{name}'" for name in unique_names)
|
281
|
+
if len(unique_names) > n:
|
267
282
|
print_values += ", ..."
|
268
283
|
return print_values
|
269
284
|
|
@@ -334,9 +349,13 @@ def _has_organism_field(orm: Registry) -> bool:
|
|
334
349
|
return False
|
335
350
|
|
336
351
|
|
337
|
-
def _get_organism_record(
|
352
|
+
def _get_organism_record(
|
353
|
+
field: StrField, organism: str | Registry, force: bool = False
|
354
|
+
) -> Registry:
|
338
355
|
model = field.field.model
|
339
|
-
if
|
356
|
+
check = True if force else field.field.name != "ensembl_gene_id"
|
357
|
+
|
358
|
+
if _has_organism_field(model) and check:
|
340
359
|
from lnschema_bionty._bionty import create_or_get_organism_record
|
341
360
|
|
342
361
|
organism_record = create_or_get_organism_record(organism=organism, orm=model)
|
lamindb/_query_manager.py
CHANGED
@@ -7,6 +7,8 @@ from lamin_utils import logger
|
|
7
7
|
from lamindb_setup.core._docs import doc_args
|
8
8
|
from lnschema_core.models import Registry
|
9
9
|
|
10
|
+
from lamindb.core._settings import settings
|
11
|
+
|
10
12
|
from .core._feature_manager import get_feature_set_by_slot
|
11
13
|
|
12
14
|
if TYPE_CHECKING:
|
@@ -41,7 +43,10 @@ class QueryManager(models.Manager):
|
|
41
43
|
from lamindb.core._data import WARNING_RUN_TRANSFORM, _track_run_input
|
42
44
|
from lamindb.core._run_context import run_context
|
43
45
|
|
44
|
-
if
|
46
|
+
if (
|
47
|
+
run_context.run is None
|
48
|
+
and not settings.silence_file_run_transform_warning
|
49
|
+
):
|
45
50
|
logger.warning(WARNING_RUN_TRANSFORM)
|
46
51
|
_track_run_input(self.instance)
|
47
52
|
|
lamindb/_registry.py
CHANGED
@@ -2,14 +2,12 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import builtins
|
4
4
|
from typing import TYPE_CHECKING, Iterable, List, NamedTuple
|
5
|
-
from uuid import UUID
|
6
5
|
|
7
6
|
import dj_database_url
|
8
7
|
import lamindb_setup as ln_setup
|
9
|
-
import pandas as pd
|
10
8
|
from django.core.exceptions import FieldDoesNotExist
|
11
9
|
from django.db import connections
|
12
|
-
from django.db.models import Manager, QuerySet
|
10
|
+
from django.db.models import Manager, Q, QuerySet
|
13
11
|
from lamin_utils import logger
|
14
12
|
from lamin_utils._lookup import Lookup
|
15
13
|
from lamin_utils._search import search as base_search
|
@@ -22,19 +20,17 @@ from lnschema_core import Registry
|
|
22
20
|
|
23
21
|
from lamindb._utils import attach_func_to_class_method
|
24
22
|
from lamindb.core._settings import settings
|
23
|
+
from lamindb.core.exceptions import ValidationError
|
25
24
|
|
26
25
|
from ._from_values import get_or_create_records
|
27
26
|
|
28
27
|
if TYPE_CHECKING:
|
28
|
+
import pandas as pd
|
29
29
|
from lnschema_core.types import ListLike, StrField
|
30
30
|
|
31
31
|
IPYTHON = getattr(builtins, "__IPYTHON__", False)
|
32
32
|
|
33
33
|
|
34
|
-
class ValidationError(Exception):
|
35
|
-
pass
|
36
|
-
|
37
|
-
|
38
34
|
def init_self_from_db(self: Registry, existing_record: Registry):
|
39
35
|
new_args = [
|
40
36
|
getattr(existing_record, field.attname) for field in self._meta.concrete_fields
|
@@ -61,20 +57,15 @@ def suggest_objects_with_same_name(orm: Registry, kwargs) -> str | None:
|
|
61
57
|
if kwargs.get("name") is None:
|
62
58
|
return None
|
63
59
|
else:
|
64
|
-
|
65
|
-
if
|
60
|
+
queryset = orm.search(kwargs["name"])
|
61
|
+
if not queryset.exists(): # empty queryset
|
66
62
|
return None
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
# test for exact match
|
72
|
-
if len(results) > 0:
|
73
|
-
if results.index[0] == kwargs["name"]:
|
74
|
-
return "object-with-same-name-exists"
|
63
|
+
else:
|
64
|
+
for record in queryset:
|
65
|
+
if record.name == kwargs["name"]:
|
66
|
+
return "object-with-same-name-exists"
|
75
67
|
else:
|
76
|
-
s = "" if
|
77
|
-
it = "it" if results.shape[0] == 1 else "one of them"
|
68
|
+
s, it = ("", "it") if len(queryset) == 1 else ("s", "one of them")
|
78
69
|
msg = (
|
79
70
|
f"record{s} with similar name{s} exist! did you mean to load {it}?"
|
80
71
|
)
|
@@ -83,9 +74,9 @@ def suggest_objects_with_same_name(orm: Registry, kwargs) -> str | None:
|
|
83
74
|
|
84
75
|
logger.warning(f"{msg}")
|
85
76
|
if settings._verbosity_int >= 1:
|
86
|
-
display(
|
77
|
+
display(queryset.df())
|
87
78
|
else:
|
88
|
-
logger.warning(f"{msg}\n{
|
79
|
+
logger.warning(f"{msg}\n{queryset}")
|
89
80
|
return None
|
90
81
|
|
91
82
|
|
@@ -162,80 +153,42 @@ def _search(
|
|
162
153
|
string: str,
|
163
154
|
*,
|
164
155
|
field: StrField | list[StrField] | None = None,
|
165
|
-
limit: int | None =
|
166
|
-
return_queryset: bool = False,
|
156
|
+
limit: int | None = 20,
|
167
157
|
case_sensitive: bool = False,
|
168
|
-
synonyms_field: StrField | None = "synonyms",
|
169
158
|
using_key: str | None = None,
|
170
|
-
) ->
|
171
|
-
|
172
|
-
orm =
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
field = get_default_str_field(orm=orm, field=field)
|
180
|
-
|
181
|
-
try:
|
182
|
-
orm._meta.get_field(synonyms_field)
|
183
|
-
synonyms_field_exists = True
|
184
|
-
except FieldDoesNotExist:
|
185
|
-
synonyms_field_exists = False
|
186
|
-
|
187
|
-
if synonyms_field is not None and synonyms_field_exists:
|
188
|
-
df = pd.DataFrame(queryset.values("uid", field, synonyms_field))
|
189
|
-
else:
|
190
|
-
df = pd.DataFrame(queryset.values("uid", field))
|
191
|
-
|
192
|
-
return base_search(
|
193
|
-
df=df,
|
194
|
-
string=string,
|
195
|
-
field=field,
|
196
|
-
limit=limit,
|
197
|
-
synonyms_field=str(synonyms_field),
|
198
|
-
case_sensitive=case_sensitive,
|
199
|
-
)
|
200
|
-
|
201
|
-
# search in both key and description fields for Artifact
|
202
|
-
if orm._meta.model.__name__ == "Artifact" and field is None:
|
203
|
-
field = ["key", "description"]
|
204
|
-
|
205
|
-
if not isinstance(field, List):
|
206
|
-
field = [field]
|
207
|
-
|
208
|
-
results = []
|
209
|
-
for fd in field:
|
210
|
-
result_field = _search_single_field(
|
211
|
-
string=string, field=fd, synonyms_field=synonyms_field
|
212
|
-
)
|
213
|
-
results.append(result_field)
|
214
|
-
# turn off synonyms search after the 1st field
|
215
|
-
synonyms_field = None
|
216
|
-
|
217
|
-
if len(results) > 1:
|
218
|
-
result = (
|
219
|
-
pd.concat([r.reset_index() for r in results], join="outer")
|
220
|
-
.drop(columns=["index"], errors="ignore")
|
221
|
-
.set_index("uid")
|
222
|
-
)
|
223
|
-
else:
|
224
|
-
result = results[0]
|
225
|
-
|
226
|
-
# remove results that have __ratio__ 0
|
227
|
-
if "__ratio__" in result.columns:
|
228
|
-
result = result[result["__ratio__"] > 0].sort_values(
|
229
|
-
"__ratio__", ascending=False
|
230
|
-
)
|
231
|
-
# restrict to 1 decimal
|
232
|
-
# move the score to be the last column
|
233
|
-
result["score"] = result.pop("__ratio__").round(1)
|
234
|
-
|
235
|
-
if return_queryset:
|
236
|
-
return _order_queryset_by_ids(queryset, result.reset_index()["uid"])
|
159
|
+
) -> QuerySet:
|
160
|
+
input_queryset = _queryset(cls, using_key=using_key)
|
161
|
+
orm = input_queryset.model
|
162
|
+
if field is None:
|
163
|
+
fields = [
|
164
|
+
field.name
|
165
|
+
for field in orm._meta.fields
|
166
|
+
if field.get_internal_type() in {"CharField", "TextField"}
|
167
|
+
]
|
237
168
|
else:
|
238
|
-
|
169
|
+
if not isinstance(field, list):
|
170
|
+
fields_input = [field]
|
171
|
+
else:
|
172
|
+
fields_input = field
|
173
|
+
fields = []
|
174
|
+
for field in fields_input:
|
175
|
+
if not isinstance(field, str):
|
176
|
+
try:
|
177
|
+
fields.append(field.field.name)
|
178
|
+
except AttributeError as error:
|
179
|
+
raise TypeError(
|
180
|
+
"Please pass a Registry string field, e.g., `CellType.name`!"
|
181
|
+
) from error
|
182
|
+
else:
|
183
|
+
fields.append(field)
|
184
|
+
expression = Q()
|
185
|
+
case_sensitive_i = "" if case_sensitive else "i"
|
186
|
+
for field in fields:
|
187
|
+
# Construct the keyword for the Q object dynamically
|
188
|
+
query = {f"{field}__{case_sensitive_i}contains": string}
|
189
|
+
expression |= Q(**query) # Unpack the dictionary into Q()
|
190
|
+
output_queryset = input_queryset.filter(expression)[:limit]
|
191
|
+
return output_queryset
|
239
192
|
|
240
193
|
|
241
194
|
@classmethod # type: ignore
|
@@ -246,19 +199,15 @@ def search(
|
|
246
199
|
*,
|
247
200
|
field: StrField | None = None,
|
248
201
|
limit: int | None = 20,
|
249
|
-
return_queryset: bool = False,
|
250
202
|
case_sensitive: bool = False,
|
251
|
-
|
252
|
-
) -> pd.DataFrame | QuerySet:
|
203
|
+
) -> QuerySet:
|
253
204
|
"""{}."""
|
254
205
|
return _search(
|
255
206
|
cls=cls,
|
256
207
|
string=string,
|
257
208
|
field=field,
|
258
|
-
return_queryset=return_queryset,
|
259
209
|
limit=limit,
|
260
210
|
case_sensitive=case_sensitive,
|
261
|
-
synonyms_field=synonyms_field,
|
262
211
|
)
|
263
212
|
|
264
213
|
|
@@ -470,7 +419,8 @@ def transfer_to_default_db(
|
|
470
419
|
if run_context.run is not None:
|
471
420
|
record.run_id = run_context.run.id
|
472
421
|
else:
|
473
|
-
|
422
|
+
if not settings.silence_file_run_transform_warning:
|
423
|
+
logger.warning(WARNING_RUN_TRANSFORM)
|
474
424
|
record.run_id = None
|
475
425
|
if hasattr(record, "transform_id") and record._meta.model_name != "run":
|
476
426
|
record.transform = None
|
@@ -535,7 +485,13 @@ def save(self, *args, **kwargs) -> Registry:
|
|
535
485
|
self_on_db._state.db = db
|
536
486
|
self_on_db.pk = pk_on_db
|
537
487
|
# by default, transfer parents of the labels to maintain ontological hierarchy
|
538
|
-
|
488
|
+
try:
|
489
|
+
import bionty as bt
|
490
|
+
|
491
|
+
parents = kwargs.get("parents", bt.settings.auto_save_parents)
|
492
|
+
except ImportError:
|
493
|
+
parents = kwargs.get("parents", True)
|
494
|
+
add_from_kwargs = {"parents": parents}
|
539
495
|
logger.info("transfer features")
|
540
496
|
self.features._add_from(self_on_db, **add_from_kwargs)
|
541
497
|
logger.info("transfer labels")
|
@@ -575,7 +531,11 @@ def __get_schema_name__(cls) -> str:
|
|
575
531
|
@classmethod # type: ignore
|
576
532
|
def __get_name_with_schema__(cls) -> str:
|
577
533
|
schema_name = cls.__get_schema_name__()
|
578
|
-
|
534
|
+
if schema_name == "core":
|
535
|
+
schema_prefix = ""
|
536
|
+
else:
|
537
|
+
schema_prefix = f"{schema_name}."
|
538
|
+
return f"{schema_prefix}{cls.__name__}"
|
579
539
|
|
580
540
|
|
581
541
|
Registry.__get_schema_name__ = __get_schema_name__
|
lamindb/_run.py
CHANGED
@@ -13,7 +13,6 @@ def __init__(run: Run, *args, **kwargs):
|
|
13
13
|
transform: Transform = None
|
14
14
|
if "transform" in kwargs or len(args) == 1:
|
15
15
|
transform = kwargs.pop("transform") if len(args) == 0 else args[0]
|
16
|
-
params: str | None = kwargs.pop("params") if "params" in kwargs else None
|
17
16
|
reference: str | None = kwargs.pop("reference") if "reference" in kwargs else None
|
18
17
|
reference_type: str | None = (
|
19
18
|
kwargs.pop("reference_type") if "reference_type" in kwargs else None
|
@@ -26,7 +25,6 @@ def __init__(run: Run, *args, **kwargs):
|
|
26
25
|
transform=transform,
|
27
26
|
reference=reference,
|
28
27
|
reference_type=reference_type,
|
29
|
-
json=params,
|
30
28
|
)
|
31
29
|
|
32
30
|
|
lamindb/_save.py
CHANGED
@@ -9,10 +9,10 @@ from functools import partial
|
|
9
9
|
from typing import TYPE_CHECKING, Iterable, overload
|
10
10
|
|
11
11
|
import lamindb_setup
|
12
|
-
from django.db import transaction
|
12
|
+
from django.db import IntegrityError, transaction
|
13
13
|
from django.utils.functional import partition
|
14
14
|
from lamin_utils import logger
|
15
|
-
from lamindb_setup.core.upath import
|
15
|
+
from lamindb_setup.core.upath import LocalPathClasses
|
16
16
|
from lnschema_core.models import Artifact, Registry
|
17
17
|
|
18
18
|
from lamindb.core._settings import settings
|
@@ -78,14 +78,15 @@ def save(
|
|
78
78
|
# for artifacts, we want to bulk-upload rather than upload one-by-one
|
79
79
|
non_artifacts, artifacts = partition(lambda r: isinstance(r, Artifact), records)
|
80
80
|
if non_artifacts:
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
bulk_create(
|
81
|
+
non_artifacts_old, non_artifacts_new = partition(
|
82
|
+
lambda r: r._state.adding or r.pk is None, non_artifacts
|
83
|
+
)
|
84
|
+
bulk_create(non_artifacts_new, ignore_conflicts=ignore_conflicts)
|
85
|
+
if non_artifacts_old:
|
86
|
+
bulk_update(non_artifacts_old)
|
85
87
|
non_artifacts_with_parents = [
|
86
|
-
r for r in
|
88
|
+
r for r in non_artifacts_new if hasattr(r, "_parents")
|
87
89
|
]
|
88
|
-
|
89
90
|
if len(non_artifacts_with_parents) > 0 and kwargs.get("parents") is not False:
|
90
91
|
# this can only happen within lnschema_bionty right now!!
|
91
92
|
# we might extend to core lamindb later
|
@@ -129,6 +130,19 @@ def bulk_create(records: Iterable[Registry], ignore_conflicts: bool | None = Fal
|
|
129
130
|
orm.objects.bulk_create(records, ignore_conflicts=ignore_conflicts)
|
130
131
|
|
131
132
|
|
133
|
+
def bulk_update(records: Iterable[Registry], ignore_conflicts: bool | None = False):
|
134
|
+
records_by_orm = defaultdict(list)
|
135
|
+
for record in records:
|
136
|
+
records_by_orm[record.__class__].append(record)
|
137
|
+
for orm, records in records_by_orm.items():
|
138
|
+
field_names = [
|
139
|
+
field.name
|
140
|
+
for field in orm._meta.fields
|
141
|
+
if (field.name != "created_at" and field.name != "id")
|
142
|
+
]
|
143
|
+
orm.objects.bulk_update(records, field_names)
|
144
|
+
|
145
|
+
|
132
146
|
# This is also used within Artifact.save()
|
133
147
|
def check_and_attempt_upload(
|
134
148
|
artifact: Artifact,
|
@@ -166,9 +180,12 @@ def copy_or_move_to_cache(artifact: Artifact, storage_path: UPath):
|
|
166
180
|
is_dir = local_path.is_dir()
|
167
181
|
cache_dir = settings._storage_settings.cache_dir
|
168
182
|
|
169
|
-
# just delete from the cache dir if
|
170
|
-
if
|
171
|
-
if
|
183
|
+
# just delete from the cache dir if storage_path is local
|
184
|
+
if isinstance(storage_path, LocalPathClasses):
|
185
|
+
if (
|
186
|
+
local_path.as_posix() != storage_path.as_posix()
|
187
|
+
and cache_dir in local_path.parents
|
188
|
+
):
|
172
189
|
if is_dir:
|
173
190
|
shutil.rmtree(local_path)
|
174
191
|
else:
|
lamindb/core/__init__.py
CHANGED
@@ -15,6 +15,8 @@ Registries:
|
|
15
15
|
IsVersioned
|
16
16
|
CanValidate
|
17
17
|
HasParents
|
18
|
+
TracksRun
|
19
|
+
TracksUpdates
|
18
20
|
InspectResult
|
19
21
|
fields
|
20
22
|
|
@@ -56,6 +58,8 @@ from lnschema_core.models import (
|
|
56
58
|
HasParents,
|
57
59
|
IsVersioned,
|
58
60
|
Registry,
|
61
|
+
TracksRun,
|
62
|
+
TracksUpdates,
|
59
63
|
)
|
60
64
|
|
61
65
|
from lamindb._annotate import (
|
lamindb/core/_data.py
CHANGED
@@ -18,10 +18,6 @@ from lnschema_core.models import (
|
|
18
18
|
format_field_value,
|
19
19
|
)
|
20
20
|
|
21
|
-
from lamindb._feature_set import (
|
22
|
-
dict_related_model_to_related_name,
|
23
|
-
dict_schema_name_to_model_name,
|
24
|
-
)
|
25
21
|
from lamindb._parents import view_lineage
|
26
22
|
from lamindb._query_set import QuerySet
|
27
23
|
from lamindb.core._settings import settings
|
@@ -36,6 +32,10 @@ from ._feature_manager import (
|
|
36
32
|
from ._label_manager import LabelManager, print_labels
|
37
33
|
from ._run_context import run_context
|
38
34
|
from .exceptions import ValidationError
|
35
|
+
from .schema import (
|
36
|
+
dict_related_model_to_related_name,
|
37
|
+
dict_schema_name_to_model_name,
|
38
|
+
)
|
39
39
|
|
40
40
|
if TYPE_CHECKING:
|
41
41
|
from lnschema_core.types import StrField
|
@@ -87,7 +87,7 @@ def save_feature_set_links(self: Artifact | Collection) -> None:
|
|
87
87
|
for slot, feature_set in self._feature_sets.items():
|
88
88
|
kwargs = {
|
89
89
|
host_id_field: self.id,
|
90
|
-
"
|
90
|
+
"featureset_id": feature_set.id,
|
91
91
|
"slot": slot,
|
92
92
|
}
|
93
93
|
links.append(Data.feature_sets.through(**kwargs))
|
@@ -114,6 +114,16 @@ def format_repr(value: Registry, exclude: list[str] | str | None = None) -> str:
|
|
114
114
|
@doc_args(Data.describe.__doc__)
|
115
115
|
def describe(self: Data):
|
116
116
|
"""{}."""
|
117
|
+
# prefetch all many-to-many relationships
|
118
|
+
# doesn't work for describing using artifact
|
119
|
+
# self = (
|
120
|
+
# self.__class__.objects.using(self._state.db)
|
121
|
+
# .prefetch_related(
|
122
|
+
# *[f.name for f in self.__class__._meta.get_fields() if f.many_to_many]
|
123
|
+
# )
|
124
|
+
# .get(id=self.id)
|
125
|
+
# )
|
126
|
+
|
117
127
|
model_name = self.__class__.__name__
|
118
128
|
msg = ""
|
119
129
|
|
@@ -125,6 +135,19 @@ def describe(self: Data):
|
|
125
135
|
foreign_key_fields.append(f.name)
|
126
136
|
else:
|
127
137
|
direct_fields.append(f.name)
|
138
|
+
if not self._state.adding:
|
139
|
+
# prefetch foreign key relationships
|
140
|
+
self = (
|
141
|
+
self.__class__.objects.using(self._state.db)
|
142
|
+
.select_related(*foreign_key_fields)
|
143
|
+
.get(id=self.id)
|
144
|
+
)
|
145
|
+
# prefetch m-2-m relationships
|
146
|
+
self = (
|
147
|
+
self.__class__.objects.using(self._state.db)
|
148
|
+
.prefetch_related("feature_sets", "input_of")
|
149
|
+
.get(id=self.id)
|
150
|
+
)
|
128
151
|
|
129
152
|
# provenance
|
130
153
|
if len(foreign_key_fields) > 0: # always True for Artifact and Collection
|
@@ -152,16 +175,13 @@ def describe(self: Data):
|
|
152
175
|
|
153
176
|
|
154
177
|
def validate_feature(feature: Feature, records: list[Registry]) -> None:
|
155
|
-
"""Validate feature record,
|
178
|
+
"""Validate feature record, adjust feature.dtype based on labels records."""
|
156
179
|
if not isinstance(feature, Feature):
|
157
180
|
raise TypeError("feature has to be of type Feature")
|
158
181
|
if feature._state.adding:
|
159
182
|
registries = {record.__class__.__get_name_with_schema__() for record in records}
|
160
183
|
registries_str = "|".join(registries)
|
161
|
-
msg = (
|
162
|
-
f"ln.Feature(name='{feature.name}', type='category',"
|
163
|
-
f" registries='{registries_str}').save()"
|
164
|
-
)
|
184
|
+
msg = f"ln.Feature(name='{feature.name}', type='cat[{registries_str}]').save()"
|
165
185
|
raise ValidationError(f"Feature not validated. If it looks correct: {msg}")
|
166
186
|
|
167
187
|
|
@@ -174,9 +194,9 @@ def get_labels(
|
|
174
194
|
"""{}."""
|
175
195
|
if not isinstance(feature, Feature):
|
176
196
|
raise TypeError("feature has to be of type Feature")
|
177
|
-
if feature.
|
197
|
+
if feature.dtype is None or not feature.dtype.startswith("cat["):
|
178
198
|
raise ValueError("feature does not have linked labels")
|
179
|
-
registries_to_check = feature.
|
199
|
+
registries_to_check = feature.dtype.replace("cat[", "").rstrip("]").split("|")
|
180
200
|
if len(registries_to_check) > 1 and not mute:
|
181
201
|
logger.warning("labels come from multiple registries!")
|
182
202
|
# return an empty query set if self.id is still None
|
@@ -186,15 +206,15 @@ def get_labels(
|
|
186
206
|
for registry in registries_to_check:
|
187
207
|
# currently need to distinguish between ULabel and non-ULabel, because
|
188
208
|
# we only have the feature information for Label
|
189
|
-
if registry == "
|
209
|
+
if registry == "ULabel":
|
190
210
|
links_to_labels = get_label_links(self, registry, feature)
|
191
211
|
label_ids = [link.ulabel_id for link in links_to_labels]
|
192
212
|
qs_by_registry[registry] = ULabel.objects.using(self._state.db).filter(
|
193
213
|
id__in=label_ids
|
194
214
|
)
|
195
|
-
|
215
|
+
elif registry in self.features.accessor_by_orm:
|
196
216
|
qs_by_registry[registry] = getattr(
|
197
|
-
self, self.features.
|
217
|
+
self, self.features.accessor_by_orm[registry]
|
198
218
|
).all()
|
199
219
|
if flat_names:
|
200
220
|
# returns a flat list of names
|
@@ -204,7 +224,7 @@ def get_labels(
|
|
204
224
|
for v in qs_by_registry.values():
|
205
225
|
values += v.list(get_default_str_field(v))
|
206
226
|
return values
|
207
|
-
if len(registries_to_check) == 1:
|
227
|
+
if len(registries_to_check) == 1 and registry in qs_by_registry:
|
208
228
|
return qs_by_registry[registry]
|
209
229
|
else:
|
210
230
|
return qs_by_registry
|
@@ -238,9 +258,9 @@ def add_labels(
|
|
238
258
|
"Please pass a feature, e.g., via: label = ln.ULabel(name='my_label',"
|
239
259
|
" feature=ln.Feature(name='my_feature'))"
|
240
260
|
)
|
241
|
-
if feature.
|
261
|
+
if feature.dtype.startswith("cat["):
|
242
262
|
orm_dict = dict_schema_name_to_model_name(Artifact)
|
243
|
-
for reg in feature.
|
263
|
+
for reg in feature.dtype.replace("cat[", "").rstrip("]").split("|"):
|
244
264
|
orm = orm_dict.get(reg)
|
245
265
|
records_validated += orm.from_values(records, field=field)
|
246
266
|
|
@@ -281,8 +301,11 @@ def add_labels(
|
|
281
301
|
record
|
282
302
|
)
|
283
303
|
for registry_name, records in records_by_registry.items():
|
304
|
+
if registry_name not in self.features.accessor_by_orm:
|
305
|
+
logger.warning(f"skipping {registry_name}")
|
306
|
+
continue
|
284
307
|
labels_accessor = getattr(
|
285
|
-
self, self.features.
|
308
|
+
self, self.features.accessor_by_orm[registry_name]
|
286
309
|
)
|
287
310
|
# remove labels that are already linked as add doesn't perform update
|
288
311
|
linked_labels = [r for r in records if r in labels_accessor.filter()]
|
@@ -290,26 +313,29 @@ def add_labels(
|
|
290
313
|
labels_accessor.remove(*linked_labels)
|
291
314
|
labels_accessor.add(*records, through_defaults={"feature_id": feature.id})
|
292
315
|
feature_set_links = get_feature_set_links(self)
|
293
|
-
feature_set_ids = [link.
|
316
|
+
feature_set_ids = [link.featureset_id for link in feature_set_links.all()]
|
294
317
|
# get all linked features of type Feature
|
295
318
|
feature_sets = FeatureSet.filter(id__in=feature_set_ids).all()
|
296
319
|
linked_features_by_slot = {
|
297
|
-
feature_set_links.filter(
|
320
|
+
feature_set_links.filter(featureset_id=feature_set.id)
|
298
321
|
.one()
|
299
322
|
.slot: feature_set.features.all()
|
300
323
|
for feature_set in feature_sets
|
301
|
-
if "
|
324
|
+
if "Feature" == feature_set.registry
|
302
325
|
}
|
303
326
|
for registry_name, _ in records_by_registry.items():
|
304
327
|
msg = ""
|
305
|
-
if
|
328
|
+
if (
|
329
|
+
not feature.dtype.startswith("cat[")
|
330
|
+
or registry_name not in feature.dtype
|
331
|
+
):
|
306
332
|
if len(msg) > 0:
|
307
333
|
msg += ", "
|
308
334
|
msg += f"linked feature '{feature.name}' to registry '{registry_name}'"
|
309
|
-
if feature.
|
310
|
-
feature.
|
311
|
-
elif registry_name not in feature.
|
312
|
-
feature.
|
335
|
+
if not feature.dtype.startswith("cat["):
|
336
|
+
feature.dtype = f"cat[{registry_name}]"
|
337
|
+
elif registry_name not in feature.dtype:
|
338
|
+
feature.dtype = feature.dtype.rstrip("]") + f"|{registry_name}]"
|
313
339
|
feature.save()
|
314
340
|
if len(msg) > 0:
|
315
341
|
logger.save(msg)
|
@@ -321,7 +347,7 @@ def add_labels(
|
|
321
347
|
found_feature = True
|
322
348
|
if not found_feature:
|
323
349
|
if "external" in linked_features_by_slot:
|
324
|
-
feature_set = self.features.
|
350
|
+
feature_set = self.features.feature_set_by_slot["external"]
|
325
351
|
features_list = feature_set.features.list()
|
326
352
|
else:
|
327
353
|
features_list = []
|
@@ -334,11 +360,11 @@ def add_labels(
|
|
334
360
|
).one()
|
335
361
|
old_feature_set_link.delete()
|
336
362
|
remaining_links = self.feature_sets.through.objects.filter(
|
337
|
-
|
363
|
+
featureset_id=feature_set.id
|
338
364
|
).all()
|
339
365
|
if len(remaining_links) == 0:
|
340
366
|
old_feature_set = FeatureSet.filter(
|
341
|
-
id=old_feature_set_link.
|
367
|
+
id=old_feature_set_link.featureset_id
|
342
368
|
).one()
|
343
369
|
logger.info(
|
344
370
|
"nothing links to it anymore, deleting feature set"
|