lamindb 0.76.7__py3-none-any.whl → 0.76.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. lamindb/__init__.py +113 -113
  2. lamindb/_artifact.py +1205 -1178
  3. lamindb/_can_validate.py +579 -579
  4. lamindb/_collection.py +387 -387
  5. lamindb/_curate.py +1601 -1601
  6. lamindb/_feature.py +155 -155
  7. lamindb/_feature_set.py +242 -242
  8. lamindb/_filter.py +23 -23
  9. lamindb/_finish.py +256 -256
  10. lamindb/_from_values.py +382 -382
  11. lamindb/_is_versioned.py +40 -40
  12. lamindb/_parents.py +476 -476
  13. lamindb/_query_manager.py +125 -125
  14. lamindb/_query_set.py +362 -362
  15. lamindb/_record.py +649 -649
  16. lamindb/_run.py +57 -57
  17. lamindb/_save.py +308 -295
  18. lamindb/_storage.py +14 -14
  19. lamindb/_transform.py +127 -127
  20. lamindb/_ulabel.py +56 -56
  21. lamindb/_utils.py +9 -9
  22. lamindb/_view.py +72 -72
  23. lamindb/core/__init__.py +94 -94
  24. lamindb/core/_context.py +574 -574
  25. lamindb/core/_data.py +438 -438
  26. lamindb/core/_feature_manager.py +867 -867
  27. lamindb/core/_label_manager.py +253 -253
  28. lamindb/core/_mapped_collection.py +597 -597
  29. lamindb/core/_settings.py +187 -187
  30. lamindb/core/_sync_git.py +138 -138
  31. lamindb/core/_track_environment.py +27 -27
  32. lamindb/core/datasets/__init__.py +59 -59
  33. lamindb/core/datasets/_core.py +571 -571
  34. lamindb/core/datasets/_fake.py +36 -36
  35. lamindb/core/exceptions.py +90 -77
  36. lamindb/core/fields.py +12 -12
  37. lamindb/core/loaders.py +164 -164
  38. lamindb/core/schema.py +56 -56
  39. lamindb/core/storage/__init__.py +25 -25
  40. lamindb/core/storage/_anndata_accessor.py +740 -740
  41. lamindb/core/storage/_anndata_sizes.py +41 -41
  42. lamindb/core/storage/_backed_access.py +98 -98
  43. lamindb/core/storage/_tiledbsoma.py +204 -204
  44. lamindb/core/storage/_valid_suffixes.py +21 -21
  45. lamindb/core/storage/_zarr.py +110 -110
  46. lamindb/core/storage/objects.py +62 -62
  47. lamindb/core/storage/paths.py +172 -141
  48. lamindb/core/subsettings/__init__.py +12 -12
  49. lamindb/core/subsettings/_creation_settings.py +38 -38
  50. lamindb/core/subsettings/_transform_settings.py +21 -21
  51. lamindb/core/types.py +19 -19
  52. lamindb/core/versioning.py +158 -158
  53. lamindb/integrations/__init__.py +12 -12
  54. lamindb/integrations/_vitessce.py +107 -107
  55. lamindb/setup/__init__.py +14 -14
  56. lamindb/setup/core/__init__.py +4 -4
  57. {lamindb-0.76.7.dist-info → lamindb-0.76.8.dist-info}/LICENSE +201 -201
  58. {lamindb-0.76.7.dist-info → lamindb-0.76.8.dist-info}/METADATA +3 -3
  59. lamindb-0.76.8.dist-info/RECORD +60 -0
  60. {lamindb-0.76.7.dist-info → lamindb-0.76.8.dist-info}/WHEEL +1 -1
  61. lamindb-0.76.7.dist-info/RECORD +0 -60
lamindb/core/_data.py CHANGED
@@ -1,438 +1,438 @@
1
- from __future__ import annotations
2
-
3
- from collections import defaultdict
4
- from typing import TYPE_CHECKING, Any, Iterable, List
5
-
6
- from lamin_utils import colors, logger
7
- from lamindb_setup.core._docs import doc_args
8
- from lnschema_core.models import (
9
- Artifact,
10
- Collection,
11
- Feature,
12
- FeatureSet,
13
- Record,
14
- Run,
15
- ULabel,
16
- format_field_value,
17
- record_repr,
18
- )
19
-
20
- from lamindb._parents import view_lineage
21
- from lamindb._query_set import QuerySet
22
- from lamindb._record import get_name_field
23
- from lamindb.core._settings import settings
24
-
25
- from ._context import context
26
- from ._feature_manager import (
27
- get_feature_set_links,
28
- get_host_id_field,
29
- get_label_links,
30
- print_features,
31
- )
32
- from ._label_manager import print_labels
33
- from .exceptions import ValidationError
34
- from .schema import (
35
- dict_related_model_to_related_name,
36
- dict_schema_name_to_model_name,
37
- )
38
-
39
- if TYPE_CHECKING:
40
- from lnschema_core.types import StrField
41
-
42
-
43
- WARNING_RUN_TRANSFORM = (
44
- "no run & transform got linked, call `ln.context.track()` & re-run`"
45
- )
46
-
47
- WARNING_NO_INPUT = "run input wasn't tracked, call `ln.context.track()` and re-run"
48
-
49
-
50
- def get_run(run: Run | None) -> Run | None:
51
- if run is None:
52
- run = context.run
53
- if run is None and not settings.creation.artifact_silence_missing_run_warning:
54
- logger.warning(WARNING_RUN_TRANSFORM)
55
- # suppress run by passing False
56
- elif not run:
57
- run = None
58
- return run
59
-
60
-
61
- def add_transform_to_kwargs(kwargs: dict[str, Any], run: Run):
62
- if run is not None:
63
- kwargs["transform"] = run.transform
64
-
65
-
66
- def save_feature_sets(self: Artifact | Collection) -> None:
67
- if hasattr(self, "_feature_sets"):
68
- saved_feature_sets = {}
69
- for key, feature_set in self._feature_sets.items():
70
- if isinstance(feature_set, FeatureSet) and feature_set._state.adding:
71
- feature_set.save()
72
- saved_feature_sets[key] = feature_set
73
- if len(saved_feature_sets) > 0:
74
- s = "s" if len(saved_feature_sets) > 1 else ""
75
- display_feature_set_keys = ",".join(
76
- f"'{key}'" for key in saved_feature_sets.keys()
77
- )
78
- logger.save(
79
- f"saved {len(saved_feature_sets)} feature set{s} for slot{s}:"
80
- f" {display_feature_set_keys}"
81
- )
82
-
83
-
84
- def save_feature_set_links(self: Artifact | Collection) -> None:
85
- from lamindb._save import bulk_create
86
-
87
- Data = self.__class__
88
- if hasattr(self, "_feature_sets"):
89
- links = []
90
- host_id_field = get_host_id_field(self)
91
- for slot, feature_set in self._feature_sets.items():
92
- kwargs = {
93
- host_id_field: self.id,
94
- "featureset_id": feature_set.id,
95
- "slot": slot,
96
- }
97
- links.append(Data.feature_sets.through(**kwargs))
98
- bulk_create(links, ignore_conflicts=True)
99
-
100
-
101
- @doc_args(Artifact.describe.__doc__)
102
- def describe(self: Artifact, print_types: bool = False):
103
- """{}""" # noqa: D415
104
- model_name = self.__class__.__name__
105
- msg = f"{colors.green(model_name)}{record_repr(self, include_foreign_keys=False).lstrip(model_name)}\n"
106
- if self._state.db is not None and self._state.db != "default":
107
- msg += f" {colors.italic('Database instance')}\n"
108
- msg += f" slug: {self._state.db}\n"
109
- # prefetch all many-to-many relationships
110
- # doesn't work for describing using artifact
111
- # self = (
112
- # self.__class__.objects.using(self._state.db)
113
- # .prefetch_related(
114
- # *[f.name for f in self.__class__._meta.get_fields() if f.many_to_many]
115
- # )
116
- # .get(id=self.id)
117
- # )
118
-
119
- prov_msg = ""
120
- fields = self._meta.fields
121
- direct_fields = []
122
- foreign_key_fields = []
123
- for f in fields:
124
- if f.is_relation:
125
- foreign_key_fields.append(f.name)
126
- else:
127
- direct_fields.append(f.name)
128
- if not self._state.adding:
129
- # prefetch foreign key relationships
130
- self = (
131
- self.__class__.objects.using(self._state.db)
132
- .select_related(*foreign_key_fields)
133
- .get(id=self.id)
134
- )
135
- # prefetch m-2-m relationships
136
- many_to_many_fields = []
137
- if isinstance(self, (Collection, Artifact)):
138
- many_to_many_fields.append("input_of_runs")
139
- if isinstance(self, Artifact):
140
- many_to_many_fields.append("feature_sets")
141
- self = (
142
- self.__class__.objects.using(self._state.db)
143
- .prefetch_related(*many_to_many_fields)
144
- .get(id=self.id)
145
- )
146
-
147
- # provenance
148
- if len(foreign_key_fields) > 0: # always True for Artifact and Collection
149
- fields_values = [(field, getattr(self, field)) for field in foreign_key_fields]
150
- type_str = lambda attr: (
151
- f": {attr.__class__.__get_name_with_schema__()}" if print_types else ""
152
- )
153
- related_msg = "".join(
154
- [
155
- f" .{field_name}{type_str(attr)} = {format_field_value(getattr(attr, get_name_field(attr)))}\n"
156
- for (field_name, attr) in fields_values
157
- if attr is not None
158
- ]
159
- )
160
- prov_msg += related_msg
161
- if prov_msg:
162
- msg += f" {colors.italic('Provenance')}\n"
163
- msg += prov_msg
164
-
165
- # input of runs
166
- input_of_message = ""
167
- if self.id is not None and self.input_of_runs.exists():
168
- values = [format_field_value(i.started_at) for i in self.input_of_runs.all()]
169
- type_str = ": Run" if print_types else "" # type: ignore
170
- input_of_message += f" .input_of_runs{type_str} = {', '.join(values)}\n"
171
- if input_of_message:
172
- msg += f" {colors.italic('Usage')}\n"
173
- msg += input_of_message
174
-
175
- # labels
176
- msg += print_labels(self, print_types=print_types)
177
-
178
- # features
179
- if isinstance(self, Artifact):
180
- msg += print_features( # type: ignore
181
- self,
182
- print_types=print_types,
183
- print_params=hasattr(self, "type") and self.type == "model",
184
- )
185
-
186
- # print entire message
187
- logger.print(msg)
188
-
189
-
190
- def validate_feature(feature: Feature, records: list[Record]) -> None:
191
- """Validate feature record, adjust feature.dtype based on labels records."""
192
- if not isinstance(feature, Feature):
193
- raise TypeError("feature has to be of type Feature")
194
- if feature._state.adding:
195
- registries = {record.__class__.__get_name_with_schema__() for record in records}
196
- registries_str = "|".join(registries)
197
- msg = f"ln.Feature(name='{feature.name}', type='cat[{registries_str}]').save()"
198
- raise ValidationError(f"Feature not validated. If it looks correct: {msg}")
199
-
200
-
201
- def get_labels(
202
- self,
203
- feature: Feature,
204
- mute: bool = False,
205
- flat_names: bool = False,
206
- ) -> QuerySet | dict[str, QuerySet] | list:
207
- """{}""" # noqa: D415
208
- if not isinstance(feature, Feature):
209
- raise TypeError("feature has to be of type Feature")
210
- if feature.dtype is None or not feature.dtype.startswith("cat["):
211
- raise ValueError("feature does not have linked labels")
212
- registries_to_check = feature.dtype.replace("cat[", "").rstrip("]").split("|")
213
- if len(registries_to_check) > 1 and not mute:
214
- logger.warning("labels come from multiple registries!")
215
- # return an empty query set if self.id is still None
216
- if self.id is None:
217
- return QuerySet(self.__class__)
218
- qs_by_registry = {}
219
- for registry in registries_to_check:
220
- # currently need to distinguish between ULabel and non-ULabel, because
221
- # we only have the feature information for Label
222
- if registry == "ULabel":
223
- links_to_labels = get_label_links(self, registry, feature)
224
- label_ids = [link.ulabel_id for link in links_to_labels]
225
- qs_by_registry[registry] = ULabel.objects.using(self._state.db).filter(
226
- id__in=label_ids
227
- )
228
- elif registry in self.features._accessor_by_registry:
229
- qs_by_registry[registry] = getattr(
230
- self, self.features._accessor_by_registry[registry]
231
- ).all()
232
- if flat_names:
233
- # returns a flat list of names
234
- from lamindb._record import get_name_field
235
-
236
- values = []
237
- for v in qs_by_registry.values():
238
- values += v.list(get_name_field(v))
239
- return values
240
- if len(registries_to_check) == 1 and registry in qs_by_registry:
241
- return qs_by_registry[registry]
242
- else:
243
- return qs_by_registry
244
-
245
-
246
- def add_labels(
247
- self,
248
- records: Record | list[Record] | QuerySet | Iterable,
249
- feature: Feature | None = None,
250
- *,
251
- field: StrField | None = None,
252
- ) -> None:
253
- """{}""" # noqa: D415
254
- if self._state.adding:
255
- raise ValueError("Please save the artifact/collection before adding a label!")
256
-
257
- if isinstance(records, (QuerySet, QuerySet.__base__)): # need to have both
258
- records = records.list()
259
- if isinstance(records, (str, Record)):
260
- records = [records]
261
- if not isinstance(records, List): # avoids warning for pd Series
262
- records = list(records)
263
- # create records from values
264
- if len(records) == 0:
265
- return None
266
- if isinstance(records[0], str): # type: ignore
267
- records_validated = []
268
- # feature is needed if we want to create records from values
269
- if feature is None:
270
- raise ValueError(
271
- "Please pass a feature, e.g., via: label = ln.ULabel(name='my_label',"
272
- " feature=ln.Feature(name='my_feature'))"
273
- )
274
- if feature.dtype.startswith("cat["):
275
- orm_dict = dict_schema_name_to_model_name(Artifact)
276
- for reg in feature.dtype.replace("cat[", "").rstrip("]").split("|"):
277
- registry = orm_dict.get(reg)
278
- records_validated += registry.from_values(records, field=field)
279
-
280
- # feature doesn't have registries and therefore can't create records from values
281
- # ask users to pass records
282
- if len(records_validated) == 0:
283
- raise ValueError(
284
- "Please pass a record (a `Record` object), not a string, e.g., via:"
285
- " label"
286
- f" = ln.ULabel(name='{records[0]}')" # type: ignore
287
- )
288
- records = records_validated
289
-
290
- for record in records:
291
- if record._state.adding:
292
- raise ValidationError(
293
- f"{record} not validated. If it looks correct: record.save()"
294
- )
295
-
296
- if feature is None:
297
- d = dict_related_model_to_related_name(self.__class__)
298
- # strategy: group records by registry to reduce number of transactions
299
- records_by_related_name: dict = {}
300
- for record in records:
301
- related_name = d.get(record.__class__.__get_name_with_schema__())
302
- if related_name is None:
303
- raise ValueError(f"Can't add labels to {record.__class__} record!")
304
- if related_name not in records_by_related_name:
305
- records_by_related_name[related_name] = []
306
- records_by_related_name[related_name].append(record)
307
- for related_name, records in records_by_related_name.items():
308
- getattr(self, related_name).add(*records)
309
- else:
310
- validate_feature(feature, records) # type:ignore
311
- records_by_registry = defaultdict(list)
312
- for record in records:
313
- records_by_registry[record.__class__.__get_name_with_schema__()].append(
314
- record
315
- )
316
- for registry_name, records in records_by_registry.items():
317
- if registry_name not in self.features._accessor_by_registry:
318
- logger.warning(f"skipping {registry_name}")
319
- continue
320
- labels_accessor = getattr(
321
- self, self.features._accessor_by_registry[registry_name]
322
- )
323
- # remove labels that are already linked as add doesn't perform update
324
- linked_labels = [r for r in records if r in labels_accessor.filter()]
325
- if len(linked_labels) > 0:
326
- labels_accessor.remove(*linked_labels)
327
- labels_accessor.add(*records, through_defaults={"feature_id": feature.id})
328
- links_feature_set = get_feature_set_links(self)
329
- feature_set_ids = [link.featureset_id for link in links_feature_set.all()]
330
- # get all linked features of type Feature
331
- feature_sets = FeatureSet.filter(id__in=feature_set_ids).all()
332
- {
333
- links_feature_set.filter(featureset_id=feature_set.id)
334
- .one()
335
- .slot: feature_set.features.all()
336
- for feature_set in feature_sets
337
- if "Feature" == feature_set.registry
338
- }
339
- for registry_name, _ in records_by_registry.items():
340
- if registry_name not in feature.dtype:
341
- logger.debug(
342
- f"updated categorical feature '{feature.name}' type with registry '{registry_name}'"
343
- )
344
- if not feature.dtype.startswith("cat["):
345
- feature.dtype = f"cat[{registry_name}]"
346
- elif registry_name not in feature.dtype:
347
- feature.dtype = feature.dtype.rstrip("]") + f"|{registry_name}]"
348
- feature.save()
349
-
350
-
351
- def _track_run_input(
352
- data: Artifact | Collection | Iterable[Artifact] | Iterable[Collection],
353
- is_run_input: bool | None = None,
354
- run: Run | None = None,
355
- ):
356
- # this is an internal hack right now for project-flow, but we can allow this
357
- # for the user in the future
358
- if isinstance(is_run_input, Run):
359
- run = is_run_input
360
- is_run_input = True
361
- elif run is None:
362
- run = context.run
363
- # consider that data is an iterable of Data
364
- data_iter: Iterable[Artifact] | Iterable[Collection] = (
365
- [data] if isinstance(data, (Artifact, Collection)) else data
366
- )
367
- track_run_input = False
368
- input_data = []
369
- if run is not None:
370
- # avoid cycles: data can't be both input and output
371
- def is_valid_input(data: Artifact | Collection):
372
- return (
373
- data.run_id != run.id
374
- and not data._state.adding
375
- and data._state.db in {"default", None}
376
- )
377
-
378
- input_data = [data for data in data_iter if is_valid_input(data)]
379
- input_data_ids = [data.id for data in input_data]
380
- if input_data:
381
- data_class_name = input_data[0].__class__.__name__.lower()
382
- # let us first look at the case in which the user does not
383
- # provide a boolean value for `is_run_input`
384
- # hence, we need to determine whether we actually want to
385
- # track a run or not
386
- if is_run_input is None:
387
- # we don't have a run record
388
- if run is None:
389
- if settings.track_run_inputs:
390
- logger.warning(WARNING_NO_INPUT)
391
- # assume we have a run record
392
- else:
393
- # assume there is non-cyclic candidate input data
394
- if input_data:
395
- if settings.track_run_inputs:
396
- transform_note = ""
397
- if len(input_data) == 1:
398
- if input_data[0].transform is not None:
399
- transform_note = (
400
- ", adding parent transform"
401
- f" {input_data[0].transform.id}"
402
- )
403
- logger.info(
404
- f"adding {data_class_name} ids {input_data_ids} as inputs for run"
405
- f" {run.id}{transform_note}"
406
- )
407
- track_run_input = True
408
- else:
409
- logger.hint(
410
- "track these data as a run input by passing `is_run_input=True`"
411
- )
412
- else:
413
- track_run_input = is_run_input
414
- if track_run_input:
415
- if run is None:
416
- raise ValueError(
417
- "No run context set. Call ln.context.track() or link input to a"
418
- " run object via `run.input_artifacts.add(artifact)`"
419
- )
420
- # avoid adding the same run twice
421
- run.save()
422
- if data_class_name == "artifact":
423
- LinkORM = run.input_artifacts.through
424
- links = [
425
- LinkORM(run_id=run.id, artifact_id=data_id)
426
- for data_id in input_data_ids
427
- ]
428
- else:
429
- LinkORM = run.input_collections.through
430
- links = [
431
- LinkORM(run_id=run.id, collection_id=data_id)
432
- for data_id in input_data_ids
433
- ]
434
- LinkORM.objects.bulk_create(links, ignore_conflicts=True)
435
- # generalize below for more than one data batch
436
- if len(input_data) == 1:
437
- if input_data[0].transform is not None:
438
- run.transform.predecessors.add(input_data[0].transform)
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict
4
+ from typing import TYPE_CHECKING, Any, Iterable, List
5
+
6
+ from lamin_utils import colors, logger
7
+ from lamindb_setup.core._docs import doc_args
8
+ from lnschema_core.models import (
9
+ Artifact,
10
+ Collection,
11
+ Feature,
12
+ FeatureSet,
13
+ Record,
14
+ Run,
15
+ ULabel,
16
+ format_field_value,
17
+ record_repr,
18
+ )
19
+
20
+ from lamindb._parents import view_lineage
21
+ from lamindb._query_set import QuerySet
22
+ from lamindb._record import get_name_field
23
+ from lamindb.core._settings import settings
24
+
25
+ from ._context import context
26
+ from ._feature_manager import (
27
+ get_feature_set_links,
28
+ get_host_id_field,
29
+ get_label_links,
30
+ print_features,
31
+ )
32
+ from ._label_manager import print_labels
33
+ from .exceptions import ValidationError
34
+ from .schema import (
35
+ dict_related_model_to_related_name,
36
+ dict_schema_name_to_model_name,
37
+ )
38
+
39
+ if TYPE_CHECKING:
40
+ from lnschema_core.types import StrField
41
+
42
+
43
+ WARNING_RUN_TRANSFORM = (
44
+ "no run & transform got linked, call `ln.context.track()` & re-run"
45
+ )
46
+
47
+ WARNING_NO_INPUT = "run input wasn't tracked, call `ln.context.track()` and re-run"
48
+
49
+
50
+ def get_run(run: Run | None) -> Run | None:
51
+ if run is None:
52
+ run = context.run
53
+ if run is None and not settings.creation.artifact_silence_missing_run_warning:
54
+ logger.warning(WARNING_RUN_TRANSFORM)
55
+ # suppress run by passing False
56
+ elif not run:
57
+ run = None
58
+ return run
59
+
60
+
61
+ def add_transform_to_kwargs(kwargs: dict[str, Any], run: Run):
62
+ if run is not None:
63
+ kwargs["transform"] = run.transform
64
+
65
+
66
+ def save_feature_sets(self: Artifact | Collection) -> None:
67
+ if hasattr(self, "_feature_sets"):
68
+ saved_feature_sets = {}
69
+ for key, feature_set in self._feature_sets.items():
70
+ if isinstance(feature_set, FeatureSet) and feature_set._state.adding:
71
+ feature_set.save()
72
+ saved_feature_sets[key] = feature_set
73
+ if len(saved_feature_sets) > 0:
74
+ s = "s" if len(saved_feature_sets) > 1 else ""
75
+ display_feature_set_keys = ",".join(
76
+ f"'{key}'" for key in saved_feature_sets.keys()
77
+ )
78
+ logger.save(
79
+ f"saved {len(saved_feature_sets)} feature set{s} for slot{s}:"
80
+ f" {display_feature_set_keys}"
81
+ )
82
+
83
+
84
+ def save_feature_set_links(self: Artifact | Collection) -> None:
85
+ from lamindb._save import bulk_create
86
+
87
+ Data = self.__class__
88
+ if hasattr(self, "_feature_sets"):
89
+ links = []
90
+ host_id_field = get_host_id_field(self)
91
+ for slot, feature_set in self._feature_sets.items():
92
+ kwargs = {
93
+ host_id_field: self.id,
94
+ "featureset_id": feature_set.id,
95
+ "slot": slot,
96
+ }
97
+ links.append(Data.feature_sets.through(**kwargs))
98
+ bulk_create(links, ignore_conflicts=True)
99
+
100
+
101
+ @doc_args(Artifact.describe.__doc__)
102
+ def describe(self: Artifact, print_types: bool = False):
103
+ """{}""" # noqa: D415
104
+ model_name = self.__class__.__name__
105
+ msg = f"{colors.green(model_name)}{record_repr(self, include_foreign_keys=False).lstrip(model_name)}\n"
106
+ if self._state.db is not None and self._state.db != "default":
107
+ msg += f" {colors.italic('Database instance')}\n"
108
+ msg += f" slug: {self._state.db}\n"
109
+ # prefetch all many-to-many relationships
110
+ # doesn't work for describing using artifact
111
+ # self = (
112
+ # self.__class__.objects.using(self._state.db)
113
+ # .prefetch_related(
114
+ # *[f.name for f in self.__class__._meta.get_fields() if f.many_to_many]
115
+ # )
116
+ # .get(id=self.id)
117
+ # )
118
+
119
+ prov_msg = ""
120
+ fields = self._meta.fields
121
+ direct_fields = []
122
+ foreign_key_fields = []
123
+ for f in fields:
124
+ if f.is_relation:
125
+ foreign_key_fields.append(f.name)
126
+ else:
127
+ direct_fields.append(f.name)
128
+ if not self._state.adding:
129
+ # prefetch foreign key relationships
130
+ self = (
131
+ self.__class__.objects.using(self._state.db)
132
+ .select_related(*foreign_key_fields)
133
+ .get(id=self.id)
134
+ )
135
+ # prefetch m-2-m relationships
136
+ many_to_many_fields = []
137
+ if isinstance(self, (Collection, Artifact)):
138
+ many_to_many_fields.append("input_of_runs")
139
+ if isinstance(self, Artifact):
140
+ many_to_many_fields.append("feature_sets")
141
+ self = (
142
+ self.__class__.objects.using(self._state.db)
143
+ .prefetch_related(*many_to_many_fields)
144
+ .get(id=self.id)
145
+ )
146
+
147
+ # provenance
148
+ if len(foreign_key_fields) > 0: # always True for Artifact and Collection
149
+ fields_values = [(field, getattr(self, field)) for field in foreign_key_fields]
150
+ type_str = lambda attr: (
151
+ f": {attr.__class__.__get_name_with_schema__()}" if print_types else ""
152
+ )
153
+ related_msg = "".join(
154
+ [
155
+ f" .{field_name}{type_str(attr)} = {format_field_value(getattr(attr, get_name_field(attr)))}\n"
156
+ for (field_name, attr) in fields_values
157
+ if attr is not None
158
+ ]
159
+ )
160
+ prov_msg += related_msg
161
+ if prov_msg:
162
+ msg += f" {colors.italic('Provenance')}\n"
163
+ msg += prov_msg
164
+
165
+ # input of runs
166
+ input_of_message = ""
167
+ if self.id is not None and self.input_of_runs.exists():
168
+ values = [format_field_value(i.started_at) for i in self.input_of_runs.all()]
169
+ type_str = ": Run" if print_types else "" # type: ignore
170
+ input_of_message += f" .input_of_runs{type_str} = {', '.join(values)}\n"
171
+ if input_of_message:
172
+ msg += f" {colors.italic('Usage')}\n"
173
+ msg += input_of_message
174
+
175
+ # labels
176
+ msg += print_labels(self, print_types=print_types)
177
+
178
+ # features
179
+ if isinstance(self, Artifact):
180
+ msg += print_features( # type: ignore
181
+ self,
182
+ print_types=print_types,
183
+ print_params=hasattr(self, "type") and self.type == "model",
184
+ )
185
+
186
+ # print entire message
187
+ logger.print(msg)
188
+
189
+
190
+ def validate_feature(feature: Feature, records: list[Record]) -> None:
191
+ """Validate feature record, adjust feature.dtype based on labels records."""
192
+ if not isinstance(feature, Feature):
193
+ raise TypeError("feature has to be of type Feature")
194
+ if feature._state.adding:
195
+ registries = {record.__class__.__get_name_with_schema__() for record in records}
196
+ registries_str = "|".join(registries)
197
+ msg = f"ln.Feature(name='{feature.name}', type='cat[{registries_str}]').save()"
198
+ raise ValidationError(f"Feature not validated. If it looks correct: {msg}")
199
+
200
+
201
+ def get_labels(
202
+ self,
203
+ feature: Feature,
204
+ mute: bool = False,
205
+ flat_names: bool = False,
206
+ ) -> QuerySet | dict[str, QuerySet] | list:
207
+ """{}""" # noqa: D415
208
+ if not isinstance(feature, Feature):
209
+ raise TypeError("feature has to be of type Feature")
210
+ if feature.dtype is None or not feature.dtype.startswith("cat["):
211
+ raise ValueError("feature does not have linked labels")
212
+ registries_to_check = feature.dtype.replace("cat[", "").rstrip("]").split("|")
213
+ if len(registries_to_check) > 1 and not mute:
214
+ logger.warning("labels come from multiple registries!")
215
+ # return an empty query set if self.id is still None
216
+ if self.id is None:
217
+ return QuerySet(self.__class__)
218
+ qs_by_registry = {}
219
+ for registry in registries_to_check:
220
+ # currently need to distinguish between ULabel and non-ULabel, because
221
+ # we only have the feature information for Label
222
+ if registry == "ULabel":
223
+ links_to_labels = get_label_links(self, registry, feature)
224
+ label_ids = [link.ulabel_id for link in links_to_labels]
225
+ qs_by_registry[registry] = ULabel.objects.using(self._state.db).filter(
226
+ id__in=label_ids
227
+ )
228
+ elif registry in self.features._accessor_by_registry:
229
+ qs_by_registry[registry] = getattr(
230
+ self, self.features._accessor_by_registry[registry]
231
+ ).all()
232
+ if flat_names:
233
+ # returns a flat list of names
234
+ from lamindb._record import get_name_field
235
+
236
+ values = []
237
+ for v in qs_by_registry.values():
238
+ values += v.list(get_name_field(v))
239
+ return values
240
+ if len(registries_to_check) == 1 and registry in qs_by_registry:
241
+ return qs_by_registry[registry]
242
+ else:
243
+ return qs_by_registry
244
+
245
+
246
+ def add_labels(
247
+ self,
248
+ records: Record | list[Record] | QuerySet | Iterable,
249
+ feature: Feature | None = None,
250
+ *,
251
+ field: StrField | None = None,
252
+ ) -> None:
253
+ """{}""" # noqa: D415
254
+ if self._state.adding:
255
+ raise ValueError("Please save the artifact/collection before adding a label!")
256
+
257
+ if isinstance(records, (QuerySet, QuerySet.__base__)): # need to have both
258
+ records = records.list()
259
+ if isinstance(records, (str, Record)):
260
+ records = [records]
261
+ if not isinstance(records, List): # avoids warning for pd Series
262
+ records = list(records)
263
+ # create records from values
264
+ if len(records) == 0:
265
+ return None
266
+ if isinstance(records[0], str): # type: ignore
267
+ records_validated = []
268
+ # feature is needed if we want to create records from values
269
+ if feature is None:
270
+ raise ValueError(
271
+ "Please pass a feature, e.g., via: label = ln.ULabel(name='my_label',"
272
+ " feature=ln.Feature(name='my_feature'))"
273
+ )
274
+ if feature.dtype.startswith("cat["):
275
+ orm_dict = dict_schema_name_to_model_name(Artifact)
276
+ for reg in feature.dtype.replace("cat[", "").rstrip("]").split("|"):
277
+ registry = orm_dict.get(reg)
278
+ records_validated += registry.from_values(records, field=field)
279
+
280
+ # feature doesn't have registries and therefore can't create records from values
281
+ # ask users to pass records
282
+ if len(records_validated) == 0:
283
+ raise ValueError(
284
+ "Please pass a record (a `Record` object), not a string, e.g., via:"
285
+ " label"
286
+ f" = ln.ULabel(name='{records[0]}')" # type: ignore
287
+ )
288
+ records = records_validated
289
+
290
+ for record in records:
291
+ if record._state.adding:
292
+ raise ValidationError(
293
+ f"{record} not validated. If it looks correct: record.save()"
294
+ )
295
+
296
+ if feature is None:
297
+ d = dict_related_model_to_related_name(self.__class__)
298
+ # strategy: group records by registry to reduce number of transactions
299
+ records_by_related_name: dict = {}
300
+ for record in records:
301
+ related_name = d.get(record.__class__.__get_name_with_schema__())
302
+ if related_name is None:
303
+ raise ValueError(f"Can't add labels to {record.__class__} record!")
304
+ if related_name not in records_by_related_name:
305
+ records_by_related_name[related_name] = []
306
+ records_by_related_name[related_name].append(record)
307
+ for related_name, records in records_by_related_name.items():
308
+ getattr(self, related_name).add(*records)
309
+ else:
310
+ validate_feature(feature, records) # type:ignore
311
+ records_by_registry = defaultdict(list)
312
+ for record in records:
313
+ records_by_registry[record.__class__.__get_name_with_schema__()].append(
314
+ record
315
+ )
316
+ for registry_name, records in records_by_registry.items():
317
+ if registry_name not in self.features._accessor_by_registry:
318
+ logger.warning(f"skipping {registry_name}")
319
+ continue
320
+ labels_accessor = getattr(
321
+ self, self.features._accessor_by_registry[registry_name]
322
+ )
323
+ # remove labels that are already linked as add doesn't perform update
324
+ linked_labels = [r for r in records if r in labels_accessor.filter()]
325
+ if len(linked_labels) > 0:
326
+ labels_accessor.remove(*linked_labels)
327
+ labels_accessor.add(*records, through_defaults={"feature_id": feature.id})
328
+ links_feature_set = get_feature_set_links(self)
329
+ feature_set_ids = [link.featureset_id for link in links_feature_set.all()]
330
+ # get all linked features of type Feature
331
+ feature_sets = FeatureSet.filter(id__in=feature_set_ids).all()
332
+ {
333
+ links_feature_set.filter(featureset_id=feature_set.id)
334
+ .one()
335
+ .slot: feature_set.features.all()
336
+ for feature_set in feature_sets
337
+ if "Feature" == feature_set.registry
338
+ }
339
+ for registry_name, _ in records_by_registry.items():
340
+ if registry_name not in feature.dtype:
341
+ logger.debug(
342
+ f"updated categorical feature '{feature.name}' type with registry '{registry_name}'"
343
+ )
344
+ if not feature.dtype.startswith("cat["):
345
+ feature.dtype = f"cat[{registry_name}]"
346
+ elif registry_name not in feature.dtype:
347
+ feature.dtype = feature.dtype.rstrip("]") + f"|{registry_name}]"
348
+ feature.save()
349
+
350
+
351
+ def _track_run_input(
352
+ data: Artifact | Collection | Iterable[Artifact] | Iterable[Collection],
353
+ is_run_input: bool | None = None,
354
+ run: Run | None = None,
355
+ ):
356
+ # this is an internal hack right now for project-flow, but we can allow this
357
+ # for the user in the future
358
+ if isinstance(is_run_input, Run):
359
+ run = is_run_input
360
+ is_run_input = True
361
+ elif run is None:
362
+ run = context.run
363
+ # consider that data is an iterable of Data
364
+ data_iter: Iterable[Artifact] | Iterable[Collection] = (
365
+ [data] if isinstance(data, (Artifact, Collection)) else data
366
+ )
367
+ track_run_input = False
368
+ input_data = []
369
+ if run is not None:
370
+ # avoid cycles: data can't be both input and output
371
+ def is_valid_input(data: Artifact | Collection):
372
+ return (
373
+ data.run_id != run.id
374
+ and not data._state.adding
375
+ and data._state.db in {"default", None}
376
+ )
377
+
378
+ input_data = [data for data in data_iter if is_valid_input(data)]
379
+ input_data_ids = [data.id for data in input_data]
380
+ if input_data:
381
+ data_class_name = input_data[0].__class__.__name__.lower()
382
+ # let us first look at the case in which the user does not
383
+ # provide a boolean value for `is_run_input`
384
+ # hence, we need to determine whether we actually want to
385
+ # track a run or not
386
+ if is_run_input is None:
387
+ # we don't have a run record
388
+ if run is None:
389
+ if settings.track_run_inputs:
390
+ logger.warning(WARNING_NO_INPUT)
391
+ # assume we have a run record
392
+ else:
393
+ # assume there is non-cyclic candidate input data
394
+ if input_data:
395
+ if settings.track_run_inputs:
396
+ transform_note = ""
397
+ if len(input_data) == 1:
398
+ if input_data[0].transform is not None:
399
+ transform_note = (
400
+ ", adding parent transform"
401
+ f" {input_data[0].transform.id}"
402
+ )
403
+ logger.info(
404
+ f"adding {data_class_name} ids {input_data_ids} as inputs for run"
405
+ f" {run.id}{transform_note}"
406
+ )
407
+ track_run_input = True
408
+ else:
409
+ logger.hint(
410
+ "track these data as a run input by passing `is_run_input=True`"
411
+ )
412
+ else:
413
+ track_run_input = is_run_input
414
+ if track_run_input:
415
+ if run is None:
416
+ raise ValueError(
417
+ "No run context set. Call ln.context.track() or link input to a"
418
+ " run object via `run.input_artifacts.add(artifact)`"
419
+ )
420
+ # avoid adding the same run twice
421
+ run.save()
422
+ if data_class_name == "artifact":
423
+ LinkORM = run.input_artifacts.through
424
+ links = [
425
+ LinkORM(run_id=run.id, artifact_id=data_id)
426
+ for data_id in input_data_ids
427
+ ]
428
+ else:
429
+ LinkORM = run.input_collections.through
430
+ links = [
431
+ LinkORM(run_id=run.id, collection_id=data_id)
432
+ for data_id in input_data_ids
433
+ ]
434
+ LinkORM.objects.bulk_create(links, ignore_conflicts=True)
435
+ # generalize below for more than one data batch
436
+ if len(input_data) == 1:
437
+ if input_data[0].transform is not None:
438
+ run.transform.predecessors.add(input_data[0].transform)