lamindb 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lamindb/__init__.py +33 -26
- lamindb/_finish.py +9 -1
- lamindb/_tracked.py +26 -3
- lamindb/_view.py +2 -3
- lamindb/base/__init__.py +1 -1
- lamindb/base/ids.py +1 -10
- lamindb/base/users.py +1 -4
- lamindb/core/__init__.py +7 -65
- lamindb/core/_compat.py +60 -0
- lamindb/core/_context.py +50 -22
- lamindb/core/_mapped_collection.py +4 -2
- lamindb/core/_settings.py +6 -6
- lamindb/core/_sync_git.py +1 -1
- lamindb/core/_track_environment.py +2 -1
- lamindb/core/datasets/_small.py +3 -3
- lamindb/core/loaders.py +43 -20
- lamindb/core/storage/_anndata_accessor.py +8 -3
- lamindb/core/storage/_backed_access.py +14 -7
- lamindb/core/storage/_pyarrow_dataset.py +24 -9
- lamindb/core/storage/_tiledbsoma.py +8 -6
- lamindb/core/storage/_zarr.py +104 -25
- lamindb/core/storage/objects.py +63 -28
- lamindb/core/storage/paths.py +16 -13
- lamindb/core/types.py +10 -0
- lamindb/curators/__init__.py +176 -149
- lamindb/errors.py +1 -1
- lamindb/integrations/_vitessce.py +4 -4
- lamindb/migrations/0089_subsequent_runs.py +159 -0
- lamindb/migrations/0090_runproject_project_runs.py +73 -0
- lamindb/migrations/{0088_squashed.py → 0090_squashed.py} +245 -177
- lamindb/models/__init__.py +79 -0
- lamindb/{core → models}/_describe.py +3 -3
- lamindb/{core → models}/_django.py +8 -5
- lamindb/{core → models}/_feature_manager.py +103 -87
- lamindb/{_from_values.py → models/_from_values.py} +5 -2
- lamindb/{core/versioning.py → models/_is_versioned.py} +94 -6
- lamindb/{core → models}/_label_manager.py +10 -17
- lamindb/{core/relations.py → models/_relations.py} +8 -1
- lamindb/models/artifact.py +2602 -0
- lamindb/{_can_curate.py → models/can_curate.py} +349 -180
- lamindb/models/collection.py +683 -0
- lamindb/models/core.py +135 -0
- lamindb/models/feature.py +643 -0
- lamindb/models/flextable.py +163 -0
- lamindb/{_parents.py → models/has_parents.py} +55 -49
- lamindb/models/project.py +384 -0
- lamindb/{_query_manager.py → models/query_manager.py} +10 -8
- lamindb/{_query_set.py → models/query_set.py} +64 -32
- lamindb/models/record.py +1762 -0
- lamindb/models/run.py +563 -0
- lamindb/{_save.py → models/save.py} +18 -8
- lamindb/models/schema.py +732 -0
- lamindb/models/transform.py +360 -0
- lamindb/models/ulabel.py +249 -0
- {lamindb-1.1.0.dist-info → lamindb-1.2.0.dist-info}/METADATA +6 -6
- lamindb-1.2.0.dist-info/RECORD +95 -0
- lamindb/_artifact.py +0 -1361
- lamindb/_collection.py +0 -440
- lamindb/_feature.py +0 -316
- lamindb/_is_versioned.py +0 -40
- lamindb/_record.py +0 -1065
- lamindb/_run.py +0 -60
- lamindb/_schema.py +0 -347
- lamindb/_storage.py +0 -15
- lamindb/_transform.py +0 -170
- lamindb/_ulabel.py +0 -56
- lamindb/_utils.py +0 -9
- lamindb/base/validation.py +0 -63
- lamindb/core/_data.py +0 -491
- lamindb/core/fields.py +0 -12
- lamindb/models.py +0 -4435
- lamindb-1.1.0.dist-info/RECORD +0 -95
- {lamindb-1.1.0.dist-info → lamindb-1.2.0.dist-info}/LICENSE +0 -0
- {lamindb-1.1.0.dist-info → lamindb-1.2.0.dist-info}/WHEEL +0 -0
lamindb/base/validation.py
DELETED
@@ -1,63 +0,0 @@
|
|
1
|
-
from typing import TYPE_CHECKING, Literal, Union, get_args, get_origin, get_type_hints
|
2
|
-
|
3
|
-
from lamin_utils import colors
|
4
|
-
|
5
|
-
from lamindb.errors import FieldValidationError
|
6
|
-
|
7
|
-
if TYPE_CHECKING:
|
8
|
-
from .models import Record
|
9
|
-
|
10
|
-
|
11
|
-
def validate_literal_fields(record: "Record", kwargs) -> None:
|
12
|
-
"""Validate all Literal type fields in a record.
|
13
|
-
|
14
|
-
Args:
|
15
|
-
record: record being validated
|
16
|
-
|
17
|
-
Raises:
|
18
|
-
ValidationError: If any field value is not in its Literal's allowed values
|
19
|
-
"""
|
20
|
-
# check is based on string to avoid circular imports
|
21
|
-
if record.__class__.__name__ == "Feature":
|
22
|
-
# the FeatureDtype is more complicated than a simple literal
|
23
|
-
# because it allows constructs like cat[ULabel] etc.
|
24
|
-
# the User model is used at startup and throws a datetime-related error otherwise
|
25
|
-
# simmilar for Storage & Source
|
26
|
-
return None
|
27
|
-
try:
|
28
|
-
type_hints = get_type_hints(record.__class__)
|
29
|
-
except TypeError:
|
30
|
-
# for 3.9, get_type_hints errors with | in type hints
|
31
|
-
return
|
32
|
-
errors = {}
|
33
|
-
|
34
|
-
for field_name, field_type in type_hints.items():
|
35
|
-
# Handle both plain Literal and Union/Optional Literal types
|
36
|
-
origin = get_origin(field_type)
|
37
|
-
if origin is Union:
|
38
|
-
# For Optional/Union types, find the Literal type if it exists
|
39
|
-
literal_type = next(
|
40
|
-
(t for t in get_args(field_type) if get_origin(t) is Literal), None
|
41
|
-
)
|
42
|
-
else:
|
43
|
-
# For plain types, check if it's a Literal
|
44
|
-
literal_type = field_type if origin is Literal else None
|
45
|
-
|
46
|
-
# Skip if no Literal type found
|
47
|
-
if literal_type is None:
|
48
|
-
continue
|
49
|
-
|
50
|
-
value = kwargs.get(field_name)
|
51
|
-
if value is not None:
|
52
|
-
valid_values = set(get_args(literal_type))
|
53
|
-
if value not in valid_values:
|
54
|
-
errors[field_name] = (
|
55
|
-
f"{field_name}: {colors.yellow(value)} is not a valid value"
|
56
|
-
f"\n → Valid values are: {colors.green(', '.join(sorted(valid_values)))}"
|
57
|
-
)
|
58
|
-
|
59
|
-
if errors:
|
60
|
-
message = "\n "
|
61
|
-
for _, error in errors.items():
|
62
|
-
message += error + "\n "
|
63
|
-
raise FieldValidationError(message)
|
lamindb/core/_data.py
DELETED
@@ -1,491 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
from collections import defaultdict
|
4
|
-
from typing import TYPE_CHECKING
|
5
|
-
|
6
|
-
from django.db import connections
|
7
|
-
from lamin_utils import colors, logger
|
8
|
-
from lamindb_setup.core._docs import doc_args
|
9
|
-
|
10
|
-
from lamindb._query_set import QuerySet
|
11
|
-
from lamindb.core._settings import settings
|
12
|
-
from lamindb.models import (
|
13
|
-
Artifact,
|
14
|
-
Collection,
|
15
|
-
Feature,
|
16
|
-
Record,
|
17
|
-
Run,
|
18
|
-
Schema,
|
19
|
-
ULabel,
|
20
|
-
format_field_value,
|
21
|
-
record_repr,
|
22
|
-
)
|
23
|
-
|
24
|
-
from .._tracked import get_current_tracked_run
|
25
|
-
from ..errors import ValidationError
|
26
|
-
from ._context import context
|
27
|
-
from ._django import get_artifact_with_related, get_related_model
|
28
|
-
from ._feature_manager import (
|
29
|
-
add_label_feature_links,
|
30
|
-
get_host_id_field,
|
31
|
-
get_label_links,
|
32
|
-
)
|
33
|
-
from .relations import (
|
34
|
-
dict_module_name_to_model_name,
|
35
|
-
dict_related_model_to_related_name,
|
36
|
-
)
|
37
|
-
|
38
|
-
if TYPE_CHECKING:
|
39
|
-
from collections.abc import Iterable
|
40
|
-
|
41
|
-
from lamindb.base.types import StrField
|
42
|
-
|
43
|
-
|
44
|
-
WARNING_RUN_TRANSFORM = "no run & transform got linked, call `ln.track()` & re-run"
|
45
|
-
|
46
|
-
WARNING_NO_INPUT = "run input wasn't tracked, call `ln.track()` and re-run"
|
47
|
-
|
48
|
-
|
49
|
-
# also see current_run() in core._data
|
50
|
-
def get_run(run: Run | None) -> Run | None:
|
51
|
-
if run is None:
|
52
|
-
run = get_current_tracked_run()
|
53
|
-
if run is None:
|
54
|
-
run = context.run
|
55
|
-
if run is None and not settings.creation.artifact_silence_missing_run_warning:
|
56
|
-
logger.warning(WARNING_RUN_TRANSFORM)
|
57
|
-
# suppress run by passing False
|
58
|
-
elif not run:
|
59
|
-
run = None
|
60
|
-
return run
|
61
|
-
|
62
|
-
|
63
|
-
def save_staged_feature_sets(self: Artifact | Collection) -> None:
|
64
|
-
if hasattr(self, "_staged_feature_sets"):
|
65
|
-
from lamindb.core._feature_manager import get_schema_by_slot_
|
66
|
-
|
67
|
-
existing_staged_feature_sets = get_schema_by_slot_(self)
|
68
|
-
saved_staged_feature_sets = {}
|
69
|
-
for key, schema in self._staged_feature_sets.items():
|
70
|
-
if isinstance(schema, Schema) and schema._state.adding:
|
71
|
-
schema.save()
|
72
|
-
saved_staged_feature_sets[key] = schema
|
73
|
-
if key in existing_staged_feature_sets:
|
74
|
-
# remove existing feature set on the same slot
|
75
|
-
self.feature_sets.remove(existing_staged_feature_sets[key])
|
76
|
-
if len(saved_staged_feature_sets) > 0:
|
77
|
-
s = "s" if len(saved_staged_feature_sets) > 1 else ""
|
78
|
-
display_schema_keys = ",".join(
|
79
|
-
f"'{key}'" for key in saved_staged_feature_sets.keys()
|
80
|
-
)
|
81
|
-
logger.save(
|
82
|
-
f"saved {len(saved_staged_feature_sets)} feature set{s} for slot{s}:"
|
83
|
-
f" {display_schema_keys}"
|
84
|
-
)
|
85
|
-
|
86
|
-
|
87
|
-
def save_schema_links(self: Artifact | Collection) -> None:
|
88
|
-
from lamindb._save import bulk_create
|
89
|
-
|
90
|
-
Data = self.__class__
|
91
|
-
if hasattr(self, "_staged_feature_sets"):
|
92
|
-
links = []
|
93
|
-
host_id_field = get_host_id_field(self)
|
94
|
-
for slot, schema in self._staged_feature_sets.items():
|
95
|
-
kwargs = {
|
96
|
-
host_id_field: self.id,
|
97
|
-
"schema_id": schema.id,
|
98
|
-
"slot": slot,
|
99
|
-
}
|
100
|
-
links.append(Data.feature_sets.through(**kwargs))
|
101
|
-
bulk_create(links, ignore_conflicts=True)
|
102
|
-
|
103
|
-
|
104
|
-
def format_provenance(self, fk_data, print_types):
|
105
|
-
type_str = lambda attr: (
|
106
|
-
f": {get_related_model(self.__class__, attr).__name__}" if print_types else ""
|
107
|
-
)
|
108
|
-
|
109
|
-
return "".join(
|
110
|
-
[
|
111
|
-
f" .{field_name}{type_str(field_name)} = {format_field_value(value.get('name'))}\n"
|
112
|
-
for field_name, value in fk_data.items()
|
113
|
-
if value.get("name")
|
114
|
-
]
|
115
|
-
)
|
116
|
-
|
117
|
-
|
118
|
-
def format_input_of_runs(self, print_types):
|
119
|
-
if self.id is not None and self.input_of_runs.exists():
|
120
|
-
values = [format_field_value(i.started_at) for i in self.input_of_runs.all()]
|
121
|
-
type_str = ": Run" if print_types else "" # type: ignore
|
122
|
-
return f" .input_of_runs{type_str} = {', '.join(values)}\n"
|
123
|
-
return ""
|
124
|
-
|
125
|
-
|
126
|
-
def _describe_postgres(self: Artifact | Collection, print_types: bool = False):
|
127
|
-
from ._describe import describe_general
|
128
|
-
from ._feature_manager import describe_features
|
129
|
-
|
130
|
-
model_name = self.__class__.__name__
|
131
|
-
msg = f"{colors.green(model_name)}{record_repr(self, include_foreign_keys=False).lstrip(model_name)}\n"
|
132
|
-
if self._state.db is not None and self._state.db != "default":
|
133
|
-
msg += f" {colors.italic('Database instance')}\n"
|
134
|
-
msg += f" slug: {self._state.db}\n"
|
135
|
-
|
136
|
-
if model_name == "Artifact":
|
137
|
-
result = get_artifact_with_related(
|
138
|
-
self,
|
139
|
-
include_feature_link=True,
|
140
|
-
include_fk=True,
|
141
|
-
include_m2m=True,
|
142
|
-
include_schema=True,
|
143
|
-
)
|
144
|
-
else:
|
145
|
-
result = get_artifact_with_related(self, include_fk=True, include_m2m=True)
|
146
|
-
related_data = result.get("related_data", {})
|
147
|
-
# TODO: fk_data = related_data.get("fk", {})
|
148
|
-
|
149
|
-
tree = describe_general(self)
|
150
|
-
return describe_features(
|
151
|
-
self,
|
152
|
-
tree=tree,
|
153
|
-
related_data=related_data,
|
154
|
-
with_labels=True,
|
155
|
-
print_params=hasattr(self, "kind") and self.kind == "model",
|
156
|
-
)
|
157
|
-
|
158
|
-
|
159
|
-
def _describe_sqlite(self: Artifact | Collection, print_types: bool = False):
|
160
|
-
from ._describe import describe_general
|
161
|
-
from ._feature_manager import describe_features
|
162
|
-
|
163
|
-
model_name = self.__class__.__name__
|
164
|
-
msg = f"{colors.green(model_name)}{record_repr(self, include_foreign_keys=False).lstrip(model_name)}\n"
|
165
|
-
if self._state.db is not None and self._state.db != "default":
|
166
|
-
msg += f" {colors.italic('Database instance')}\n"
|
167
|
-
msg += f" slug: {self._state.db}\n"
|
168
|
-
|
169
|
-
fields = self._meta.fields
|
170
|
-
direct_fields = []
|
171
|
-
foreign_key_fields = []
|
172
|
-
for f in fields:
|
173
|
-
if f.is_relation:
|
174
|
-
foreign_key_fields.append(f.name)
|
175
|
-
else:
|
176
|
-
direct_fields.append(f.name)
|
177
|
-
if not self._state.adding:
|
178
|
-
# prefetch foreign key relationships
|
179
|
-
self = (
|
180
|
-
self.__class__.objects.using(self._state.db)
|
181
|
-
.select_related(*foreign_key_fields)
|
182
|
-
.get(id=self.id)
|
183
|
-
)
|
184
|
-
# prefetch m-2-m relationships
|
185
|
-
many_to_many_fields = []
|
186
|
-
if isinstance(self, (Collection, Artifact)):
|
187
|
-
many_to_many_fields.append("input_of_runs")
|
188
|
-
if isinstance(self, Artifact):
|
189
|
-
many_to_many_fields.append("feature_sets")
|
190
|
-
self = (
|
191
|
-
self.__class__.objects.using(self._state.db)
|
192
|
-
.prefetch_related(*many_to_many_fields)
|
193
|
-
.get(id=self.id)
|
194
|
-
)
|
195
|
-
tree = describe_general(self)
|
196
|
-
return describe_features(
|
197
|
-
self,
|
198
|
-
tree=tree,
|
199
|
-
with_labels=True,
|
200
|
-
print_params=hasattr(self, "kind") and self.kind == "kind",
|
201
|
-
)
|
202
|
-
|
203
|
-
|
204
|
-
@doc_args(Artifact.describe.__doc__)
|
205
|
-
def describe(self: Artifact | Collection, print_types: bool = False):
|
206
|
-
"""{}""" # noqa: D415
|
207
|
-
from ._describe import print_rich_tree
|
208
|
-
|
209
|
-
if not self._state.adding and connections[self._state.db].vendor == "postgresql":
|
210
|
-
tree = _describe_postgres(self, print_types=print_types)
|
211
|
-
else:
|
212
|
-
tree = _describe_sqlite(self, print_types=print_types)
|
213
|
-
|
214
|
-
print_rich_tree(tree)
|
215
|
-
|
216
|
-
|
217
|
-
def validate_feature(feature: Feature, records: list[Record]) -> None:
|
218
|
-
"""Validate feature record, adjust feature.dtype based on labels records."""
|
219
|
-
if not isinstance(feature, Feature):
|
220
|
-
raise TypeError("feature has to be of type Feature")
|
221
|
-
if feature._state.adding:
|
222
|
-
registries = {record.__class__.__get_name_with_module__() for record in records}
|
223
|
-
registries_str = "|".join(registries)
|
224
|
-
msg = f"ln.Feature(name='{feature.name}', type='cat[{registries_str}]').save()"
|
225
|
-
raise ValidationError(f"Feature not validated. If it looks correct: {msg}")
|
226
|
-
|
227
|
-
|
228
|
-
def get_labels(
|
229
|
-
self,
|
230
|
-
feature: Feature,
|
231
|
-
mute: bool = False,
|
232
|
-
flat_names: bool = False,
|
233
|
-
) -> QuerySet | dict[str, QuerySet] | list:
|
234
|
-
"""{}""" # noqa: D415
|
235
|
-
if not isinstance(feature, Feature):
|
236
|
-
raise TypeError("feature has to be of type Feature")
|
237
|
-
if feature.dtype is None or not feature.dtype.startswith("cat["):
|
238
|
-
raise ValueError("feature does not have linked labels")
|
239
|
-
registries_to_check = feature.dtype.replace("cat[", "").rstrip("]").split("|")
|
240
|
-
if len(registries_to_check) > 1 and not mute:
|
241
|
-
logger.warning("labels come from multiple registries!")
|
242
|
-
# return an empty query set if self.id is still None
|
243
|
-
if self.id is None:
|
244
|
-
return QuerySet(self.__class__)
|
245
|
-
qs_by_registry = {}
|
246
|
-
for registry in registries_to_check:
|
247
|
-
# currently need to distinguish between ULabel and non-ULabel, because
|
248
|
-
# we only have the feature information for Label
|
249
|
-
if registry == "ULabel":
|
250
|
-
links_to_labels = get_label_links(self, registry, feature)
|
251
|
-
label_ids = [link.ulabel_id for link in links_to_labels]
|
252
|
-
qs_by_registry[registry] = ULabel.objects.using(self._state.db).filter(
|
253
|
-
id__in=label_ids
|
254
|
-
)
|
255
|
-
elif registry in self.features._accessor_by_registry:
|
256
|
-
qs_by_registry[registry] = getattr(
|
257
|
-
self, self.features._accessor_by_registry[registry]
|
258
|
-
).all()
|
259
|
-
if flat_names:
|
260
|
-
# returns a flat list of names
|
261
|
-
from lamindb._record import get_name_field
|
262
|
-
|
263
|
-
values = []
|
264
|
-
for v in qs_by_registry.values():
|
265
|
-
values += v.list(get_name_field(v))
|
266
|
-
return values
|
267
|
-
if len(registries_to_check) == 1 and registry in qs_by_registry:
|
268
|
-
return qs_by_registry[registry]
|
269
|
-
else:
|
270
|
-
return qs_by_registry
|
271
|
-
|
272
|
-
|
273
|
-
def add_labels(
|
274
|
-
self,
|
275
|
-
records: Record | list[Record] | QuerySet | Iterable,
|
276
|
-
feature: Feature | None = None,
|
277
|
-
*,
|
278
|
-
field: StrField | None = None,
|
279
|
-
feature_ref_is_name: bool | None = None,
|
280
|
-
label_ref_is_name: bool | None = None,
|
281
|
-
from_curator: bool = False,
|
282
|
-
) -> None:
|
283
|
-
"""{}""" # noqa: D415
|
284
|
-
if self._state.adding:
|
285
|
-
raise ValueError("Please save the artifact/collection before adding a label!")
|
286
|
-
|
287
|
-
if isinstance(records, (QuerySet, QuerySet.__base__)): # need to have both
|
288
|
-
records = records.list()
|
289
|
-
if isinstance(records, (str, Record)):
|
290
|
-
records = [records]
|
291
|
-
if not isinstance(records, list): # avoids warning for pd Series
|
292
|
-
records = list(records)
|
293
|
-
# create records from values
|
294
|
-
if len(records) == 0:
|
295
|
-
return None
|
296
|
-
if isinstance(records[0], str): # type: ignore
|
297
|
-
records_validated = []
|
298
|
-
# feature is needed if we want to create records from values
|
299
|
-
if feature is None:
|
300
|
-
raise ValueError(
|
301
|
-
"Please pass a feature, e.g., via: label = ln.ULabel(name='my_label',"
|
302
|
-
" feature=ln.Feature(name='my_feature'))"
|
303
|
-
)
|
304
|
-
if feature.dtype.startswith("cat["):
|
305
|
-
orm_dict = dict_module_name_to_model_name(Artifact)
|
306
|
-
for reg in feature.dtype.replace("cat[", "").rstrip("]").split("|"):
|
307
|
-
registry = orm_dict.get(reg)
|
308
|
-
records_validated += registry.from_values(records, field=field)
|
309
|
-
|
310
|
-
# feature doesn't have registries and therefore can't create records from values
|
311
|
-
# ask users to pass records
|
312
|
-
if len(records_validated) == 0:
|
313
|
-
raise ValueError(
|
314
|
-
"Please pass a record (a `Record` object), not a string, e.g., via:"
|
315
|
-
" label"
|
316
|
-
f" = ln.ULabel(name='{records[0]}')" # type: ignore
|
317
|
-
)
|
318
|
-
records = records_validated
|
319
|
-
|
320
|
-
for record in records:
|
321
|
-
if record._state.adding:
|
322
|
-
raise ValidationError(
|
323
|
-
f"{record} not validated. If it looks correct: record.save()"
|
324
|
-
)
|
325
|
-
|
326
|
-
if feature is None:
|
327
|
-
d = dict_related_model_to_related_name(self.__class__)
|
328
|
-
# strategy: group records by registry to reduce number of transactions
|
329
|
-
records_by_related_name: dict = {}
|
330
|
-
for record in records:
|
331
|
-
related_name = d.get(record.__class__.__get_name_with_module__())
|
332
|
-
if related_name is None:
|
333
|
-
raise ValueError(f"Can't add labels to {record.__class__} record!")
|
334
|
-
if related_name not in records_by_related_name:
|
335
|
-
records_by_related_name[related_name] = []
|
336
|
-
records_by_related_name[related_name].append(record)
|
337
|
-
for related_name, records in records_by_related_name.items():
|
338
|
-
getattr(self, related_name).add(*records)
|
339
|
-
else:
|
340
|
-
validate_feature(feature, records) # type:ignore
|
341
|
-
records_by_registry = defaultdict(list)
|
342
|
-
feature_sets = self.feature_sets.filter(itype="Feature").all()
|
343
|
-
internal_features = set() # type: ignore
|
344
|
-
if len(feature_sets) > 0:
|
345
|
-
for schema in feature_sets:
|
346
|
-
internal_features = internal_features.union(
|
347
|
-
set(schema.members.values_list("name", flat=True))
|
348
|
-
) # type: ignore
|
349
|
-
for record in records:
|
350
|
-
records_by_registry[record.__class__.__get_name_with_module__()].append(
|
351
|
-
record
|
352
|
-
)
|
353
|
-
for registry_name, records in records_by_registry.items():
|
354
|
-
if not from_curator and feature.name in internal_features:
|
355
|
-
raise ValidationError(
|
356
|
-
"Cannot manually annotate internal feature with label. Please use ln.Curator"
|
357
|
-
)
|
358
|
-
if registry_name not in feature.dtype:
|
359
|
-
if not feature.dtype.startswith("cat"):
|
360
|
-
raise ValidationError(
|
361
|
-
f"Feature {feature.name} needs dtype='cat' for label annotation, currently has dtype='{feature.dtype}'"
|
362
|
-
)
|
363
|
-
if feature.dtype == "cat":
|
364
|
-
feature.dtype = f"cat[{registry_name}]" # type: ignore
|
365
|
-
feature.save()
|
366
|
-
elif registry_name not in feature.dtype:
|
367
|
-
new_dtype = feature.dtype.rstrip("]") + f"|{registry_name}]"
|
368
|
-
raise ValidationError(
|
369
|
-
f"Label type {registry_name} is not valid for Feature(name='{feature.name}', dtype='{feature.dtype}'), consider updating to dtype='{new_dtype}'"
|
370
|
-
)
|
371
|
-
|
372
|
-
if registry_name not in self.features._accessor_by_registry:
|
373
|
-
logger.warning(f"skipping {registry_name}")
|
374
|
-
continue
|
375
|
-
if len(records) == 0:
|
376
|
-
continue
|
377
|
-
features_labels = {
|
378
|
-
registry_name: [(feature, label_record) for label_record in records]
|
379
|
-
}
|
380
|
-
add_label_feature_links(
|
381
|
-
self.features,
|
382
|
-
features_labels,
|
383
|
-
feature_ref_is_name=feature_ref_is_name,
|
384
|
-
label_ref_is_name=label_ref_is_name,
|
385
|
-
)
|
386
|
-
|
387
|
-
|
388
|
-
def _track_run_input(
|
389
|
-
data: Artifact | Collection | Iterable[Artifact] | Iterable[Collection],
|
390
|
-
is_run_input: bool | Run | None = None,
|
391
|
-
run: Run | None = None,
|
392
|
-
):
|
393
|
-
if isinstance(is_run_input, Run):
|
394
|
-
run = is_run_input
|
395
|
-
is_run_input = True
|
396
|
-
elif run is None:
|
397
|
-
run = get_current_tracked_run()
|
398
|
-
if run is None:
|
399
|
-
run = context.run
|
400
|
-
# consider that data is an iterable of Data
|
401
|
-
data_iter: Iterable[Artifact] | Iterable[Collection] = (
|
402
|
-
[data] if isinstance(data, (Artifact, Collection)) else data
|
403
|
-
)
|
404
|
-
track_run_input = False
|
405
|
-
input_data = []
|
406
|
-
if run is not None:
|
407
|
-
# avoid cycles: data can't be both input and output
|
408
|
-
def is_valid_input(data: Artifact | Collection):
|
409
|
-
is_valid = False
|
410
|
-
if data._state.db == "default":
|
411
|
-
# things are OK if the record is on the default db
|
412
|
-
is_valid = True
|
413
|
-
elif data._state.db is None:
|
414
|
-
# if a record is not yet saved, it can't be an input
|
415
|
-
# we silently ignore because what likely happens is that
|
416
|
-
# the user works with an object that's about to be saved
|
417
|
-
# in the current Python session
|
418
|
-
is_valid = False
|
419
|
-
else:
|
420
|
-
# record is on another db
|
421
|
-
# we have to save the record into the current db with
|
422
|
-
# the run being attached to a transfer transform
|
423
|
-
logger.important(
|
424
|
-
f"completing transfer to track {data.__class__.__name__}('{data.uid[:8]}') as input"
|
425
|
-
)
|
426
|
-
data.save()
|
427
|
-
is_valid = True
|
428
|
-
return (
|
429
|
-
data.run_id != run.id
|
430
|
-
and not data._state.adding # this seems duplicated with data._state.db is None
|
431
|
-
and is_valid
|
432
|
-
)
|
433
|
-
|
434
|
-
input_data = [data for data in data_iter if is_valid_input(data)]
|
435
|
-
input_data_ids = [data.id for data in input_data]
|
436
|
-
if input_data:
|
437
|
-
data_class_name = input_data[0].__class__.__name__.lower()
|
438
|
-
# let us first look at the case in which the user does not
|
439
|
-
# provide a boolean value for `is_run_input`
|
440
|
-
# hence, we need to determine whether we actually want to
|
441
|
-
# track a run or not
|
442
|
-
if is_run_input is None:
|
443
|
-
# we don't have a run record
|
444
|
-
if run is None:
|
445
|
-
if settings.track_run_inputs:
|
446
|
-
logger.warning(WARNING_NO_INPUT)
|
447
|
-
# assume we have a run record
|
448
|
-
else:
|
449
|
-
# assume there is non-cyclic candidate input data
|
450
|
-
if input_data:
|
451
|
-
if settings.track_run_inputs:
|
452
|
-
transform_note = ""
|
453
|
-
if len(input_data) == 1:
|
454
|
-
if input_data[0].transform is not None:
|
455
|
-
transform_note = (
|
456
|
-
", adding parent transform"
|
457
|
-
f" {input_data[0].transform.id}"
|
458
|
-
)
|
459
|
-
logger.info(
|
460
|
-
f"adding {data_class_name} ids {input_data_ids} as inputs for run"
|
461
|
-
f" {run.id}{transform_note}"
|
462
|
-
)
|
463
|
-
track_run_input = True
|
464
|
-
else:
|
465
|
-
logger.hint(
|
466
|
-
"track these data as a run input by passing `is_run_input=True`"
|
467
|
-
)
|
468
|
-
else:
|
469
|
-
track_run_input = is_run_input
|
470
|
-
if track_run_input:
|
471
|
-
if run is None:
|
472
|
-
raise ValueError("No run context set. Call `ln.track()`.")
|
473
|
-
# avoid adding the same run twice
|
474
|
-
run.save()
|
475
|
-
if data_class_name == "artifact":
|
476
|
-
LinkORM = run.input_artifacts.through
|
477
|
-
links = [
|
478
|
-
LinkORM(run_id=run.id, artifact_id=data_id)
|
479
|
-
for data_id in input_data_ids
|
480
|
-
]
|
481
|
-
else:
|
482
|
-
LinkORM = run.input_collections.through
|
483
|
-
links = [
|
484
|
-
LinkORM(run_id=run.id, collection_id=data_id)
|
485
|
-
for data_id in input_data_ids
|
486
|
-
]
|
487
|
-
LinkORM.objects.bulk_create(links, ignore_conflicts=True)
|
488
|
-
# generalize below for more than one data batch
|
489
|
-
if len(input_data) == 1:
|
490
|
-
if input_data[0].transform is not None:
|
491
|
-
run.transform.predecessors.add(input_data[0].transform)
|