lamindb 0.76.8__py3-none-any.whl → 0.76.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. lamindb/__init__.py +114 -113
  2. lamindb/_artifact.py +1206 -1205
  3. lamindb/_can_validate.py +621 -579
  4. lamindb/_collection.py +390 -387
  5. lamindb/_curate.py +1603 -1601
  6. lamindb/_feature.py +155 -155
  7. lamindb/_feature_set.py +244 -242
  8. lamindb/_filter.py +23 -23
  9. lamindb/_finish.py +250 -256
  10. lamindb/_from_values.py +403 -382
  11. lamindb/_is_versioned.py +40 -40
  12. lamindb/_parents.py +476 -476
  13. lamindb/_query_manager.py +125 -125
  14. lamindb/_query_set.py +364 -362
  15. lamindb/_record.py +668 -649
  16. lamindb/_run.py +60 -57
  17. lamindb/_save.py +310 -308
  18. lamindb/_storage.py +14 -14
  19. lamindb/_transform.py +130 -127
  20. lamindb/_ulabel.py +56 -56
  21. lamindb/_utils.py +9 -9
  22. lamindb/_view.py +72 -72
  23. lamindb/core/__init__.py +94 -94
  24. lamindb/core/_context.py +590 -574
  25. lamindb/core/_data.py +510 -438
  26. lamindb/core/_django.py +209 -0
  27. lamindb/core/_feature_manager.py +994 -867
  28. lamindb/core/_label_manager.py +289 -253
  29. lamindb/core/_mapped_collection.py +631 -597
  30. lamindb/core/_settings.py +188 -187
  31. lamindb/core/_sync_git.py +138 -138
  32. lamindb/core/_track_environment.py +27 -27
  33. lamindb/core/datasets/__init__.py +59 -59
  34. lamindb/core/datasets/_core.py +581 -571
  35. lamindb/core/datasets/_fake.py +36 -36
  36. lamindb/core/exceptions.py +90 -90
  37. lamindb/core/fields.py +12 -12
  38. lamindb/core/loaders.py +164 -164
  39. lamindb/core/schema.py +56 -56
  40. lamindb/core/storage/__init__.py +25 -25
  41. lamindb/core/storage/_anndata_accessor.py +741 -740
  42. lamindb/core/storage/_anndata_sizes.py +41 -41
  43. lamindb/core/storage/_backed_access.py +98 -98
  44. lamindb/core/storage/_tiledbsoma.py +204 -204
  45. lamindb/core/storage/_valid_suffixes.py +21 -21
  46. lamindb/core/storage/_zarr.py +110 -110
  47. lamindb/core/storage/objects.py +62 -62
  48. lamindb/core/storage/paths.py +172 -172
  49. lamindb/core/subsettings/__init__.py +12 -12
  50. lamindb/core/subsettings/_creation_settings.py +38 -38
  51. lamindb/core/subsettings/_transform_settings.py +21 -21
  52. lamindb/core/types.py +19 -19
  53. lamindb/core/versioning.py +146 -158
  54. lamindb/integrations/__init__.py +12 -12
  55. lamindb/integrations/_vitessce.py +107 -107
  56. lamindb/setup/__init__.py +14 -14
  57. lamindb/setup/core/__init__.py +4 -4
  58. {lamindb-0.76.8.dist-info → lamindb-0.76.10.dist-info}/LICENSE +201 -201
  59. {lamindb-0.76.8.dist-info → lamindb-0.76.10.dist-info}/METADATA +8 -8
  60. lamindb-0.76.10.dist-info/RECORD +61 -0
  61. {lamindb-0.76.8.dist-info → lamindb-0.76.10.dist-info}/WHEEL +1 -1
  62. lamindb-0.76.8.dist-info/RECORD +0 -60
lamindb/core/_data.py CHANGED
@@ -1,438 +1,510 @@
- from __future__ import annotations
-
- from collections import defaultdict
- from typing import TYPE_CHECKING, Any, Iterable, List
-
- from lamin_utils import colors, logger
- from lamindb_setup.core._docs import doc_args
- from lnschema_core.models import (
- Artifact,
- Collection,
- Feature,
- FeatureSet,
- Record,
- Run,
- ULabel,
- format_field_value,
- record_repr,
- )
-
- from lamindb._parents import view_lineage
- from lamindb._query_set import QuerySet
- from lamindb._record import get_name_field
- from lamindb.core._settings import settings
-
- from ._context import context
- from ._feature_manager import (
- get_feature_set_links,
- get_host_id_field,
- get_label_links,
- print_features,
- )
- from ._label_manager import print_labels
- from .exceptions import ValidationError
- from .schema import (
- dict_related_model_to_related_name,
- dict_schema_name_to_model_name,
- )
-
- if TYPE_CHECKING:
- from lnschema_core.types import StrField
-
-
- WARNING_RUN_TRANSFORM = (
- "no run & transform got linked, call `ln.context.track()` & re-run"
- )
-
- WARNING_NO_INPUT = "run input wasn't tracked, call `ln.context.track()` and re-run"
-
-
- def get_run(run: Run | None) -> Run | None:
- if run is None:
- run = context.run
- if run is None and not settings.creation.artifact_silence_missing_run_warning:
- logger.warning(WARNING_RUN_TRANSFORM)
- # suppress run by passing False
- elif not run:
- run = None
- return run
-
-
- def add_transform_to_kwargs(kwargs: dict[str, Any], run: Run):
- if run is not None:
- kwargs["transform"] = run.transform
-
-
- def save_feature_sets(self: Artifact | Collection) -> None:
- if hasattr(self, "_feature_sets"):
- saved_feature_sets = {}
- for key, feature_set in self._feature_sets.items():
- if isinstance(feature_set, FeatureSet) and feature_set._state.adding:
- feature_set.save()
- saved_feature_sets[key] = feature_set
- if len(saved_feature_sets) > 0:
- s = "s" if len(saved_feature_sets) > 1 else ""
- display_feature_set_keys = ",".join(
- f"'{key}'" for key in saved_feature_sets.keys()
- )
- logger.save(
- f"saved {len(saved_feature_sets)} feature set{s} for slot{s}:"
- f" {display_feature_set_keys}"
- )
-
-
- def save_feature_set_links(self: Artifact | Collection) -> None:
- from lamindb._save import bulk_create
-
- Data = self.__class__
- if hasattr(self, "_feature_sets"):
- links = []
- host_id_field = get_host_id_field(self)
- for slot, feature_set in self._feature_sets.items():
- kwargs = {
- host_id_field: self.id,
- "featureset_id": feature_set.id,
- "slot": slot,
- }
- links.append(Data.feature_sets.through(**kwargs))
- bulk_create(links, ignore_conflicts=True)
-
-
- @doc_args(Artifact.describe.__doc__)
- def describe(self: Artifact, print_types: bool = False):
- """{}""" # noqa: D415
- model_name = self.__class__.__name__
- msg = f"{colors.green(model_name)}{record_repr(self, include_foreign_keys=False).lstrip(model_name)}\n"
- if self._state.db is not None and self._state.db != "default":
- msg += f" {colors.italic('Database instance')}\n"
- msg += f" slug: {self._state.db}\n"
- # prefetch all many-to-many relationships
- # doesn't work for describing using artifact
- # self = (
- # self.__class__.objects.using(self._state.db)
- # .prefetch_related(
- # *[f.name for f in self.__class__._meta.get_fields() if f.many_to_many]
- # )
- # .get(id=self.id)
- # )
-
- prov_msg = ""
- fields = self._meta.fields
- direct_fields = []
- foreign_key_fields = []
- for f in fields:
- if f.is_relation:
- foreign_key_fields.append(f.name)
- else:
- direct_fields.append(f.name)
- if not self._state.adding:
- # prefetch foreign key relationships
- self = (
- self.__class__.objects.using(self._state.db)
- .select_related(*foreign_key_fields)
- .get(id=self.id)
- )
- # prefetch m-2-m relationships
- many_to_many_fields = []
- if isinstance(self, (Collection, Artifact)):
- many_to_many_fields.append("input_of_runs")
- if isinstance(self, Artifact):
- many_to_many_fields.append("feature_sets")
- self = (
- self.__class__.objects.using(self._state.db)
- .prefetch_related(*many_to_many_fields)
- .get(id=self.id)
- )
-
- # provenance
- if len(foreign_key_fields) > 0: # always True for Artifact and Collection
- fields_values = [(field, getattr(self, field)) for field in foreign_key_fields]
- type_str = lambda attr: (
- f": {attr.__class__.__get_name_with_schema__()}" if print_types else ""
- )
- related_msg = "".join(
- [
- f" .{field_name}{type_str(attr)} = {format_field_value(getattr(attr, get_name_field(attr)))}\n"
- for (field_name, attr) in fields_values
- if attr is not None
- ]
- )
- prov_msg += related_msg
- if prov_msg:
- msg += f" {colors.italic('Provenance')}\n"
- msg += prov_msg
-
- # input of runs
- input_of_message = ""
- if self.id is not None and self.input_of_runs.exists():
- values = [format_field_value(i.started_at) for i in self.input_of_runs.all()]
- type_str = ": Run" if print_types else "" # type: ignore
- input_of_message += f" .input_of_runs{type_str} = {', '.join(values)}\n"
- if input_of_message:
- msg += f" {colors.italic('Usage')}\n"
- msg += input_of_message
-
- # labels
- msg += print_labels(self, print_types=print_types)
-
- # features
- if isinstance(self, Artifact):
- msg += print_features( # type: ignore
- self,
- print_types=print_types,
- print_params=hasattr(self, "type") and self.type == "model",
- )
-
- # print entire message
- logger.print(msg)
-
-
- def validate_feature(feature: Feature, records: list[Record]) -> None:
- """Validate feature record, adjust feature.dtype based on labels records."""
- if not isinstance(feature, Feature):
- raise TypeError("feature has to be of type Feature")
- if feature._state.adding:
- registries = {record.__class__.__get_name_with_schema__() for record in records}
- registries_str = "|".join(registries)
- msg = f"ln.Feature(name='{feature.name}', type='cat[{registries_str}]').save()"
- raise ValidationError(f"Feature not validated. If it looks correct: {msg}")
-
-
- def get_labels(
- self,
- feature: Feature,
- mute: bool = False,
- flat_names: bool = False,
- ) -> QuerySet | dict[str, QuerySet] | list:
- """{}""" # noqa: D415
- if not isinstance(feature, Feature):
- raise TypeError("feature has to be of type Feature")
- if feature.dtype is None or not feature.dtype.startswith("cat["):
- raise ValueError("feature does not have linked labels")
- registries_to_check = feature.dtype.replace("cat[", "").rstrip("]").split("|")
- if len(registries_to_check) > 1 and not mute:
- logger.warning("labels come from multiple registries!")
- # return an empty query set if self.id is still None
- if self.id is None:
- return QuerySet(self.__class__)
- qs_by_registry = {}
- for registry in registries_to_check:
- # currently need to distinguish between ULabel and non-ULabel, because
- # we only have the feature information for Label
- if registry == "ULabel":
- links_to_labels = get_label_links(self, registry, feature)
- label_ids = [link.ulabel_id for link in links_to_labels]
- qs_by_registry[registry] = ULabel.objects.using(self._state.db).filter(
- id__in=label_ids
- )
- elif registry in self.features._accessor_by_registry:
- qs_by_registry[registry] = getattr(
- self, self.features._accessor_by_registry[registry]
- ).all()
- if flat_names:
- # returns a flat list of names
- from lamindb._record import get_name_field
-
- values = []
- for v in qs_by_registry.values():
- values += v.list(get_name_field(v))
- return values
- if len(registries_to_check) == 1 and registry in qs_by_registry:
- return qs_by_registry[registry]
- else:
- return qs_by_registry
-
-
- def add_labels(
- self,
- records: Record | list[Record] | QuerySet | Iterable,
- feature: Feature | None = None,
- *,
- field: StrField | None = None,
- ) -> None:
- """{}""" # noqa: D415
- if self._state.adding:
- raise ValueError("Please save the artifact/collection before adding a label!")
-
- if isinstance(records, (QuerySet, QuerySet.__base__)): # need to have both
- records = records.list()
- if isinstance(records, (str, Record)):
- records = [records]
- if not isinstance(records, List): # avoids warning for pd Series
- records = list(records)
- # create records from values
- if len(records) == 0:
- return None
- if isinstance(records[0], str): # type: ignore
- records_validated = []
- # feature is needed if we want to create records from values
- if feature is None:
- raise ValueError(
- "Please pass a feature, e.g., via: label = ln.ULabel(name='my_label',"
- " feature=ln.Feature(name='my_feature'))"
- )
- if feature.dtype.startswith("cat["):
- orm_dict = dict_schema_name_to_model_name(Artifact)
- for reg in feature.dtype.replace("cat[", "").rstrip("]").split("|"):
- registry = orm_dict.get(reg)
- records_validated += registry.from_values(records, field=field)
-
- # feature doesn't have registries and therefore can't create records from values
- # ask users to pass records
- if len(records_validated) == 0:
- raise ValueError(
- "Please pass a record (a `Record` object), not a string, e.g., via:"
- " label"
- f" = ln.ULabel(name='{records[0]}')" # type: ignore
- )
- records = records_validated
-
- for record in records:
- if record._state.adding:
- raise ValidationError(
- f"{record} not validated. If it looks correct: record.save()"
- )
-
- if feature is None:
- d = dict_related_model_to_related_name(self.__class__)
- # strategy: group records by registry to reduce number of transactions
- records_by_related_name: dict = {}
- for record in records:
- related_name = d.get(record.__class__.__get_name_with_schema__())
- if related_name is None:
- raise ValueError(f"Can't add labels to {record.__class__} record!")
- if related_name not in records_by_related_name:
- records_by_related_name[related_name] = []
- records_by_related_name[related_name].append(record)
- for related_name, records in records_by_related_name.items():
- getattr(self, related_name).add(*records)
- else:
- validate_feature(feature, records) # type:ignore
- records_by_registry = defaultdict(list)
- for record in records:
- records_by_registry[record.__class__.__get_name_with_schema__()].append(
- record
- )
- for registry_name, records in records_by_registry.items():
- if registry_name not in self.features._accessor_by_registry:
- logger.warning(f"skipping {registry_name}")
- continue
- labels_accessor = getattr(
- self, self.features._accessor_by_registry[registry_name]
- )
- # remove labels that are already linked as add doesn't perform update
- linked_labels = [r for r in records if r in labels_accessor.filter()]
- if len(linked_labels) > 0:
- labels_accessor.remove(*linked_labels)
- labels_accessor.add(*records, through_defaults={"feature_id": feature.id})
- links_feature_set = get_feature_set_links(self)
- feature_set_ids = [link.featureset_id for link in links_feature_set.all()]
- # get all linked features of type Feature
- feature_sets = FeatureSet.filter(id__in=feature_set_ids).all()
- {
- links_feature_set.filter(featureset_id=feature_set.id)
- .one()
- .slot: feature_set.features.all()
- for feature_set in feature_sets
- if "Feature" == feature_set.registry
- }
- for registry_name, _ in records_by_registry.items():
- if registry_name not in feature.dtype:
- logger.debug(
- f"updated categorical feature '{feature.name}' type with registry '{registry_name}'"
- )
- if not feature.dtype.startswith("cat["):
- feature.dtype = f"cat[{registry_name}]"
- elif registry_name not in feature.dtype:
- feature.dtype = feature.dtype.rstrip("]") + f"|{registry_name}]"
- feature.save()
-
-
- def _track_run_input(
- data: Artifact | Collection | Iterable[Artifact] | Iterable[Collection],
- is_run_input: bool | None = None,
- run: Run | None = None,
- ):
- # this is an internal hack right now for project-flow, but we can allow this
- # for the user in the future
- if isinstance(is_run_input, Run):
- run = is_run_input
- is_run_input = True
- elif run is None:
- run = context.run
- # consider that data is an iterable of Data
- data_iter: Iterable[Artifact] | Iterable[Collection] = (
- [data] if isinstance(data, (Artifact, Collection)) else data
- )
- track_run_input = False
- input_data = []
- if run is not None:
- # avoid cycles: data can't be both input and output
- def is_valid_input(data: Artifact | Collection):
- return (
- data.run_id != run.id
- and not data._state.adding
- and data._state.db in {"default", None}
- )
-
- input_data = [data for data in data_iter if is_valid_input(data)]
- input_data_ids = [data.id for data in input_data]
- if input_data:
- data_class_name = input_data[0].__class__.__name__.lower()
- # let us first look at the case in which the user does not
- # provide a boolean value for `is_run_input`
- # hence, we need to determine whether we actually want to
- # track a run or not
- if is_run_input is None:
- # we don't have a run record
- if run is None:
- if settings.track_run_inputs:
- logger.warning(WARNING_NO_INPUT)
- # assume we have a run record
- else:
- # assume there is non-cyclic candidate input data
- if input_data:
- if settings.track_run_inputs:
- transform_note = ""
- if len(input_data) == 1:
- if input_data[0].transform is not None:
- transform_note = (
- ", adding parent transform"
- f" {input_data[0].transform.id}"
- )
- logger.info(
- f"adding {data_class_name} ids {input_data_ids} as inputs for run"
- f" {run.id}{transform_note}"
- )
- track_run_input = True
- else:
- logger.hint(
- "track these data as a run input by passing `is_run_input=True`"
- )
- else:
- track_run_input = is_run_input
- if track_run_input:
- if run is None:
- raise ValueError(
- "No run context set. Call ln.context.track() or link input to a"
- " run object via `run.input_artifacts.add(artifact)`"
- )
- # avoid adding the same run twice
- run.save()
- if data_class_name == "artifact":
- LinkORM = run.input_artifacts.through
- links = [
- LinkORM(run_id=run.id, artifact_id=data_id)
- for data_id in input_data_ids
- ]
- else:
- LinkORM = run.input_collections.through
- links = [
- LinkORM(run_id=run.id, collection_id=data_id)
- for data_id in input_data_ids
- ]
- LinkORM.objects.bulk_create(links, ignore_conflicts=True)
- # generalize below for more than one data batch
- if len(input_data) == 1:
- if input_data[0].transform is not None:
- run.transform.predecessors.add(input_data[0].transform)
+ from __future__ import annotations
+
+ from collections import defaultdict
+ from typing import TYPE_CHECKING, Any
+
+ from django.db import connections
+ from lamin_utils import colors, logger
+ from lamindb_setup.core._docs import doc_args
+ from lnschema_core.models import (
+ Artifact,
+ Collection,
+ Feature,
+ FeatureSet,
+ Record,
+ Run,
+ ULabel,
+ format_field_value,
+ record_repr,
+ )
+
+ from lamindb._parents import view_lineage
+ from lamindb._query_set import QuerySet
+ from lamindb._record import get_name_field
+ from lamindb.core._settings import settings
+
+ from ._context import context
+ from ._django import get_artifact_with_related, get_related_model
+ from ._feature_manager import (
+ get_feature_set_links,
+ get_host_id_field,
+ get_label_links,
+ print_features,
+ )
+ from ._label_manager import print_labels
+ from .exceptions import ValidationError
+ from .schema import (
+ dict_related_model_to_related_name,
+ dict_schema_name_to_model_name,
+ )
+
+ if TYPE_CHECKING:
+ from collections.abc import Iterable
+
+ from lnschema_core.types import StrField
+
+
+ WARNING_RUN_TRANSFORM = "no run & transform got linked, call `ln.track()` & re-run"
+
+ WARNING_NO_INPUT = "run input wasn't tracked, call `ln.track()` and re-run"
+
+
+ def get_run(run: Run | None) -> Run | None:
+ if run is None:
+ run = context.run
+ if run is None and not settings.creation.artifact_silence_missing_run_warning:
+ logger.warning(WARNING_RUN_TRANSFORM)
+ # suppress run by passing False
+ elif not run:
+ run = None
+ return run
+
+
+ def add_transform_to_kwargs(kwargs: dict[str, Any], run: Run):
+ if run is not None:
+ kwargs["transform"] = run.transform
+
+
+ def save_feature_sets(self: Artifact | Collection) -> None:
+ if hasattr(self, "_feature_sets"):
+ saved_feature_sets = {}
+ for key, feature_set in self._feature_sets.items():
+ if isinstance(feature_set, FeatureSet) and feature_set._state.adding:
+ feature_set.save()
+ saved_feature_sets[key] = feature_set
+ if len(saved_feature_sets) > 0:
+ s = "s" if len(saved_feature_sets) > 1 else ""
+ display_feature_set_keys = ",".join(
+ f"'{key}'" for key in saved_feature_sets.keys()
+ )
+ logger.save(
+ f"saved {len(saved_feature_sets)} feature set{s} for slot{s}:"
+ f" {display_feature_set_keys}"
+ )
+
+
+ def save_feature_set_links(self: Artifact | Collection) -> None:
+ from lamindb._save import bulk_create
+
+ Data = self.__class__
+ if hasattr(self, "_feature_sets"):
+ links = []
+ host_id_field = get_host_id_field(self)
+ for slot, feature_set in self._feature_sets.items():
+ kwargs = {
+ host_id_field: self.id,
+ "featureset_id": feature_set.id,
+ "slot": slot,
+ }
+ links.append(Data.feature_sets.through(**kwargs))
+ bulk_create(links, ignore_conflicts=True)
+
+
+ def format_provenance(self, fk_data, print_types):
+ type_str = lambda attr: (
+ f": {get_related_model(self.__class__, attr).__name__}" if print_types else ""
+ )
+
+ return "".join(
+ [
+ f" .{field_name}{type_str(field_name)} = {format_field_value(value.get('name'))}\n"
+ for field_name, value in fk_data.items()
+ if value.get("name")
+ ]
+ )
+
+
+ def format_input_of_runs(self, print_types):
+ if self.id is not None and self.input_of_runs.exists():
+ values = [format_field_value(i.started_at) for i in self.input_of_runs.all()]
+ type_str = ": Run" if print_types else "" # type: ignore
+ return f" .input_of_runs{type_str} = {', '.join(values)}\n"
+ return ""
+
+
+ def format_labels_and_features(self, related_data, print_types):
+ msg = print_labels(
+ self, m2m_data=related_data.get("m2m", {}), print_types=print_types
+ )
+ if isinstance(self, Artifact):
+ msg += print_features( # type: ignore
+ self,
+ related_data=related_data,
+ print_types=print_types,
+ print_params=hasattr(self, "type") and self.type == "model",
+ )
+ return msg
+
+
+ def _describe_postgres(self: Artifact | Collection, print_types: bool = False):
+ model_name = self.__class__.__name__
+ msg = f"{colors.green(model_name)}{record_repr(self, include_foreign_keys=False).lstrip(model_name)}\n"
+ if self._state.db is not None and self._state.db != "default":
+ msg += f" {colors.italic('Database instance')}\n"
+ msg += f" slug: {self._state.db}\n"
+
+ if model_name == "Artifact":
+ result = get_artifact_with_related(
+ self,
+ include_feature_link=True,
+ include_fk=True,
+ include_m2m=True,
+ include_featureset=True,
+ )
+ else:
+ result = get_artifact_with_related(self, include_fk=True, include_m2m=True)
+ related_data = result.get("related_data", {})
+ fk_data = related_data.get("fk", {})
+
+ # Provenance
+ prov_msg = format_provenance(self, fk_data, print_types)
+ if prov_msg:
+ msg += f" {colors.italic('Provenance')}\n{prov_msg}"
+
+ # Input of runs
+ input_of_message = format_input_of_runs(self, print_types)
+ if input_of_message:
+ msg += f" {colors.italic('Usage')}\n{input_of_message}"
+
+ # Labels and features
+ msg += format_labels_and_features(self, related_data, print_types)
+
+ # Print entire message
+ logger.print(msg)
+
+
+ @doc_args(Artifact.describe.__doc__)
+ def describe(self: Artifact | Collection, print_types: bool = False):
+ """{}""" # noqa: D415
+ if not self._state.adding and connections[self._state.db].vendor == "postgresql":
+ return _describe_postgres(self, print_types=print_types)
+
+ model_name = self.__class__.__name__
+ msg = f"{colors.green(model_name)}{record_repr(self, include_foreign_keys=False).lstrip(model_name)}\n"
+ if self._state.db is not None and self._state.db != "default":
+ msg += f" {colors.italic('Database instance')}\n"
+ msg += f" slug: {self._state.db}\n"
+
+ prov_msg = ""
+ fields = self._meta.fields
+ direct_fields = []
+ foreign_key_fields = []
+ for f in fields:
+ if f.is_relation:
+ foreign_key_fields.append(f.name)
+ else:
+ direct_fields.append(f.name)
+ if not self._state.adding:
+ # prefetch foreign key relationships
+ self = (
+ self.__class__.objects.using(self._state.db)
+ .select_related(*foreign_key_fields)
+ .get(id=self.id)
+ )
+ # prefetch m-2-m relationships
+ many_to_many_fields = []
+ if isinstance(self, (Collection, Artifact)):
+ many_to_many_fields.append("input_of_runs")
+ if isinstance(self, Artifact):
+ many_to_many_fields.append("feature_sets")
+ self = (
+ self.__class__.objects.using(self._state.db)
+ .prefetch_related(*many_to_many_fields)
+ .get(id=self.id)
+ )
+
+ # provenance
+ if len(foreign_key_fields) > 0: # always True for Artifact and Collection
+ fields_values = [(field, getattr(self, field)) for field in foreign_key_fields]
+ type_str = lambda attr: (
+ f": {attr.__class__.__get_name_with_schema__()}" if print_types else ""
+ )
+ related_msg = "".join(
+ [
+ f" .{field_name}{type_str(attr)} = {format_field_value(getattr(attr, get_name_field(attr)))}\n"
+ for (field_name, attr) in fields_values
+ if attr is not None
+ ]
+ )
+ prov_msg += related_msg
+ if prov_msg:
+ msg += f" {colors.italic('Provenance')}\n"
+ msg += prov_msg
+
+ # Input of runs
+ input_of_message = format_input_of_runs(self, print_types)
+ if input_of_message:
+ msg += f" {colors.italic('Usage')}\n{input_of_message}"
+
+ # Labels and features
+ msg += format_labels_and_features(self, {}, print_types)
+
+ # Print entire message
+ logger.print(msg)
+
+
+ def validate_feature(feature: Feature, records: list[Record]) -> None:
+ """Validate feature record, adjust feature.dtype based on labels records."""
+ if not isinstance(feature, Feature):
+ raise TypeError("feature has to be of type Feature")
+ if feature._state.adding:
+ registries = {record.__class__.__get_name_with_schema__() for record in records}
+ registries_str = "|".join(registries)
+ msg = f"ln.Feature(name='{feature.name}', type='cat[{registries_str}]').save()"
+ raise ValidationError(f"Feature not validated. If it looks correct: {msg}")
+
+
+ def get_labels(
+ self,
+ feature: Feature,
+ mute: bool = False,
+ flat_names: bool = False,
+ ) -> QuerySet | dict[str, QuerySet] | list:
+ """{}""" # noqa: D415
+ if not isinstance(feature, Feature):
+ raise TypeError("feature has to be of type Feature")
+ if feature.dtype is None or not feature.dtype.startswith("cat["):
+ raise ValueError("feature does not have linked labels")
+ registries_to_check = feature.dtype.replace("cat[", "").rstrip("]").split("|")
+ if len(registries_to_check) > 1 and not mute:
+ logger.warning("labels come from multiple registries!")
+ # return an empty query set if self.id is still None
+ if self.id is None:
+ return QuerySet(self.__class__)
+ qs_by_registry = {}
+ for registry in registries_to_check:
+ # currently need to distinguish between ULabel and non-ULabel, because
+ # we only have the feature information for Label
+ if registry == "ULabel":
+ links_to_labels = get_label_links(self, registry, feature)
+ label_ids = [link.ulabel_id for link in links_to_labels]
+ qs_by_registry[registry] = ULabel.objects.using(self._state.db).filter(
+ id__in=label_ids
+ )
+ elif registry in self.features._accessor_by_registry:
+ qs_by_registry[registry] = getattr(
+ self, self.features._accessor_by_registry[registry]
+ ).all()
+ if flat_names:
+ # returns a flat list of names
+ from lamindb._record import get_name_field
+
+ values = []
+ for v in qs_by_registry.values():
+ values += v.list(get_name_field(v))
+ return values
+ if len(registries_to_check) == 1 and registry in qs_by_registry:
+ return qs_by_registry[registry]
+ else:
+ return qs_by_registry
+
+
+ def add_labels(
+ self,
+ records: Record | list[Record] | QuerySet | Iterable,
+ feature: Feature | None = None,
+ *,
+ field: StrField | None = None,
+ ) -> None:
+ """{}""" # noqa: D415
+ if self._state.adding:
+ raise ValueError("Please save the artifact/collection before adding a label!")
+
+ if isinstance(records, (QuerySet, QuerySet.__base__)): # need to have both
+ records = records.list()
+ if isinstance(records, (str, Record)):
+ records = [records]
+ if not isinstance(records, list): # avoids warning for pd Series
+ records = list(records)
+ # create records from values
+ if len(records) == 0:
+ return None
+ if isinstance(records[0], str): # type: ignore
+ records_validated = []
+ # feature is needed if we want to create records from values
+ if feature is None:
+ raise ValueError(
+ "Please pass a feature, e.g., via: label = ln.ULabel(name='my_label',"
+ " feature=ln.Feature(name='my_feature'))"
+ )
+ if feature.dtype.startswith("cat["):
+ orm_dict = dict_schema_name_to_model_name(Artifact)
+ for reg in feature.dtype.replace("cat[", "").rstrip("]").split("|"):
+ registry = orm_dict.get(reg)
+ records_validated += registry.from_values(records, field=field)
+
+ # feature doesn't have registries and therefore can't create records from values
+ # ask users to pass records
+ if len(records_validated) == 0:
+ raise ValueError(
+ "Please pass a record (a `Record` object), not a string, e.g., via:"
+ " label"
+ f" = ln.ULabel(name='{records[0]}')" # type: ignore
+ )
+ records = records_validated
+
+ for record in records:
+ if record._state.adding:
+ raise ValidationError(
+ f"{record} not validated. If it looks correct: record.save()"
+ )
+
+ if feature is None:
+ d = dict_related_model_to_related_name(self.__class__)
+ # strategy: group records by registry to reduce number of transactions
+ records_by_related_name: dict = {}
+ for record in records:
+ related_name = d.get(record.__class__.__get_name_with_schema__())
+ if related_name is None:
+ raise ValueError(f"Can't add labels to {record.__class__} record!")
+ if related_name not in records_by_related_name:
+ records_by_related_name[related_name] = []
+ records_by_related_name[related_name].append(record)
+ for related_name, records in records_by_related_name.items():
+ getattr(self, related_name).add(*records)
+ else:
+ validate_feature(feature, records) # type:ignore
+ records_by_registry = defaultdict(list)
+ for record in records:
+ records_by_registry[record.__class__.__get_name_with_schema__()].append(
+ record
+ )
+ for registry_name, records in records_by_registry.items():
+ if registry_name not in self.features._accessor_by_registry:
+ logger.warning(f"skipping {registry_name}")
+ continue
+ labels_accessor = getattr(
+ self, self.features._accessor_by_registry[registry_name]
+ )
+ # remove labels that are already linked as add doesn't perform update
+ linked_labels = [r for r in records if r in labels_accessor.filter()]
+ if len(linked_labels) > 0:
+ labels_accessor.remove(*linked_labels)
+ labels_accessor.add(*records, through_defaults={"feature_id": feature.id})
+ links_feature_set = get_feature_set_links(self)
+ feature_set_ids = [link.featureset_id for link in links_feature_set.all()]
+ # get all linked features of type Feature
+ feature_sets = FeatureSet.filter(id__in=feature_set_ids).all()
+ {
+ links_feature_set.filter(featureset_id=feature_set.id)
+ .one()
+ .slot: feature_set.features.all()
+ for feature_set in feature_sets
+ if "Feature" == feature_set.registry
+ }
+ for registry_name, _ in records_by_registry.items():
+ if registry_name not in feature.dtype:
+ logger.debug(
+ f"updated categorical feature '{feature.name}' type with registry '{registry_name}'"
+ )
+ if not feature.dtype.startswith("cat["):
+ feature.dtype = f"cat[{registry_name}]"
+ elif registry_name not in feature.dtype:
+ feature.dtype = feature.dtype.rstrip("]") + f"|{registry_name}]"
+ feature.save()
+
+
+ def _track_run_input(
+ data: Artifact | Collection | Iterable[Artifact] | Iterable[Collection],
+ is_run_input: bool | None = None,
+ run: Run | None = None,
+ ):
+ # this is an internal hack right now for project-flow, but we can allow this
+ # for the user in the future
+ if isinstance(is_run_input, Run):
+ run = is_run_input
+ is_run_input = True
+ elif run is None:
+ run = context.run
+ # consider that data is an iterable of Data
+ data_iter: Iterable[Artifact] | Iterable[Collection] = (
+ [data] if isinstance(data, (Artifact, Collection)) else data
+ )
+ track_run_input = False
+ input_data = []
+ if run is not None:
+ # avoid cycles: data can't be both input and output
+ def is_valid_input(data: Artifact | Collection):
+ is_valid = False
+ if data._state.db == "default":
+ # things are OK if the record is on the default db
+ is_valid = True
+ elif data._state.db is None:
+ # if a record is not yet saved, it can't be an input
+ # we silently ignore because what likely happens is that
+ # the user works with an object that's about to be saved
+ # in the current Python session
+ is_valid = False
+ else:
+ # record is on another db
+ # we have to save the record into the current db with
+ # the run being attached to a transfer transform
+ logger.important(
+ f"completing transfer to track {data.__class__.__name__}('{data.uid[:8]}') as input"
+ )
+ data.save()
+ is_valid = True
+ return (
+ data.run_id != run.id
+ and not data._state.adding # this seems duplicated with data._state.db is None
+ and is_valid
+ )
+
+ input_data = [data for data in data_iter if is_valid_input(data)]
+ input_data_ids = [data.id for data in input_data]
+ if input_data:
+ data_class_name = input_data[0].__class__.__name__.lower()
+ # let us first look at the case in which the user does not
+ # provide a boolean value for `is_run_input`
+ # hence, we need to determine whether we actually want to
+ # track a run or not
+ if is_run_input is None:
+ # we don't have a run record
+ if run is None:
+ if settings.track_run_inputs:
+ logger.warning(WARNING_NO_INPUT)
+ # assume we have a run record
+ else:
+ # assume there is non-cyclic candidate input data
+ if input_data:
+ if settings.track_run_inputs:
+ transform_note = ""
+ if len(input_data) == 1:
+ if input_data[0].transform is not None:
+ transform_note = (
+ ", adding parent transform"
+ f" {input_data[0].transform.id}"
+ )
+ logger.info(
+ f"adding {data_class_name} ids {input_data_ids} as inputs for run"
+ f" {run.id}{transform_note}"
+ )
+ track_run_input = True
+ else:
+ logger.hint(
+ "track these data as a run input by passing `is_run_input=True`"
+ )
+ else:
+ track_run_input = is_run_input
+ if track_run_input:
+ if run is None:
+ raise ValueError("No run context set. Call `ln.track()`.")
+ # avoid adding the same run twice
+ run.save()
+ if data_class_name == "artifact":
+ LinkORM = run.input_artifacts.through
+ links = [
+ LinkORM(run_id=run.id, artifact_id=data_id)
+ for data_id in input_data_ids
+ ]
+ else:
+ LinkORM = run.input_collections.through
+ links = [
+ LinkORM(run_id=run.id, collection_id=data_id)
+ for data_id in input_data_ids
+ ]
+ LinkORM.objects.bulk_create(links, ignore_conflicts=True)
+ # generalize below for more than one data batch
+ if len(input_data) == 1:
+ if input_data[0].transform is not None:
+ run.transform.predecessors.add(input_data[0].transform)