lamindb 0.76.8__py3-none-any.whl → 0.76.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. lamindb/__init__.py +113 -113
  2. lamindb/_artifact.py +1205 -1205
  3. lamindb/_can_validate.py +579 -579
  4. lamindb/_collection.py +389 -387
  5. lamindb/_curate.py +1601 -1601
  6. lamindb/_feature.py +155 -155
  7. lamindb/_feature_set.py +242 -242
  8. lamindb/_filter.py +23 -23
  9. lamindb/_finish.py +256 -256
  10. lamindb/_from_values.py +382 -382
  11. lamindb/_is_versioned.py +40 -40
  12. lamindb/_parents.py +476 -476
  13. lamindb/_query_manager.py +125 -125
  14. lamindb/_query_set.py +362 -362
  15. lamindb/_record.py +649 -649
  16. lamindb/_run.py +57 -57
  17. lamindb/_save.py +308 -308
  18. lamindb/_storage.py +14 -14
  19. lamindb/_transform.py +127 -127
  20. lamindb/_ulabel.py +56 -56
  21. lamindb/_utils.py +9 -9
  22. lamindb/_view.py +72 -72
  23. lamindb/core/__init__.py +94 -94
  24. lamindb/core/_context.py +574 -574
  25. lamindb/core/_data.py +438 -438
  26. lamindb/core/_feature_manager.py +867 -867
  27. lamindb/core/_label_manager.py +253 -253
  28. lamindb/core/_mapped_collection.py +631 -597
  29. lamindb/core/_settings.py +187 -187
  30. lamindb/core/_sync_git.py +138 -138
  31. lamindb/core/_track_environment.py +27 -27
  32. lamindb/core/datasets/__init__.py +59 -59
  33. lamindb/core/datasets/_core.py +581 -571
  34. lamindb/core/datasets/_fake.py +36 -36
  35. lamindb/core/exceptions.py +90 -90
  36. lamindb/core/fields.py +12 -12
  37. lamindb/core/loaders.py +164 -164
  38. lamindb/core/schema.py +56 -56
  39. lamindb/core/storage/__init__.py +25 -25
  40. lamindb/core/storage/_anndata_accessor.py +740 -740
  41. lamindb/core/storage/_anndata_sizes.py +41 -41
  42. lamindb/core/storage/_backed_access.py +98 -98
  43. lamindb/core/storage/_tiledbsoma.py +204 -204
  44. lamindb/core/storage/_valid_suffixes.py +21 -21
  45. lamindb/core/storage/_zarr.py +110 -110
  46. lamindb/core/storage/objects.py +62 -62
  47. lamindb/core/storage/paths.py +172 -172
  48. lamindb/core/subsettings/__init__.py +12 -12
  49. lamindb/core/subsettings/_creation_settings.py +38 -38
  50. lamindb/core/subsettings/_transform_settings.py +21 -21
  51. lamindb/core/types.py +19 -19
  52. lamindb/core/versioning.py +158 -158
  53. lamindb/integrations/__init__.py +12 -12
  54. lamindb/integrations/_vitessce.py +107 -107
  55. lamindb/setup/__init__.py +14 -14
  56. lamindb/setup/core/__init__.py +4 -4
  57. {lamindb-0.76.8.dist-info → lamindb-0.76.9.dist-info}/LICENSE +201 -201
  58. {lamindb-0.76.8.dist-info → lamindb-0.76.9.dist-info}/METADATA +4 -4
  59. lamindb-0.76.9.dist-info/RECORD +60 -0
  60. {lamindb-0.76.8.dist-info → lamindb-0.76.9.dist-info}/WHEEL +1 -1
  61. lamindb-0.76.8.dist-info/RECORD +0 -60
@@ -1,867 +1,867 @@
1
- from __future__ import annotations
2
-
3
- from collections import defaultdict
4
- from collections.abc import Iterable
5
- from itertools import compress
6
- from typing import TYPE_CHECKING, Any
7
-
8
- import anndata as ad
9
- import numpy as np
10
- import pandas as pd
11
- from anndata import AnnData
12
- from django.contrib.postgres.aggregates import ArrayAgg
13
- from django.db import connections
14
- from django.db.models import Aggregate
15
- from lamin_utils import colors, logger
16
- from lamindb_setup.core.upath import create_path
17
- from lnschema_core.models import (
18
- Artifact,
19
- Collection,
20
- Feature,
21
- FeatureManager,
22
- FeatureValue,
23
- LinkORM,
24
- Param,
25
- ParamManager,
26
- ParamManagerArtifact,
27
- ParamManagerRun,
28
- ParamValue,
29
- Record,
30
- Run,
31
- ULabel,
32
- )
33
-
34
- from lamindb._feature import FEATURE_TYPES, convert_numpy_dtype_to_lamin_feature_type
35
- from lamindb._feature_set import DICT_KEYS_TYPE, FeatureSet
36
- from lamindb._record import (
37
- REGISTRY_UNIQUE_FIELD,
38
- get_name_field,
39
- transfer_fk_to_default_db_bulk,
40
- transfer_to_default_db,
41
- )
42
- from lamindb._save import save
43
- from lamindb.core.exceptions import ValidationError
44
- from lamindb.core.storage import LocalPathClasses
45
-
46
- from ._label_manager import get_labels_as_dict
47
- from ._settings import settings
48
- from .schema import (
49
- dict_related_model_to_related_name,
50
- )
51
-
52
- if TYPE_CHECKING:
53
- from lnschema_core.types import FieldAttr
54
-
55
- from lamindb._query_set import QuerySet
56
-
57
-
58
def get_host_id_field(host: Artifact | Collection) -> str:
    """Return the FK column name that link tables use to point at *host*."""
    return "artifact_id" if isinstance(host, Artifact) else "collection_id"
64
-
65
-
66
def get_accessor_by_registry_(host: Artifact | Collection) -> dict:
    """Map registry names (with schema prefix) to the accessor name on *host*."""
    accessors = {}
    for related in host._meta.related_objects:
        accessors[related.related_model.__get_name_with_schema__()] = related.name
    # these two are not discoverable via _meta.related_objects
    accessors["Feature"] = "features"
    accessors["ULabel"] = "ulabels"
    return accessors
74
-
75
-
76
def get_feature_set_by_slot_(host) -> dict:
    """Return ``{slot: feature_set}`` for *host*; queries the DB only when saved."""
    # an unsaved host can only carry feature sets in a transient attribute
    if host._state.adding:
        return getattr(host, "_feature_sets", {})
    # saved host: query the link table on the host's database
    query_kwargs = {get_host_id_field(host): host.id}
    links = (
        host.feature_sets.through.objects.using(host._state.db)
        .filter(**query_kwargs)
        .select_related("featureset")
    )
    return {link.slot: link.featureset for link in links}
93
-
94
-
95
def get_label_links(
    host: Artifact | Collection, registry: str, feature: Feature
) -> QuerySet:
    """Return label link records between *host* and *registry* for *feature*."""
    accessor_name = host.features._accessor_by_registry[registry]
    filter_kwargs = {
        get_host_id_field(host): host.id,
        "feature_id": feature.id,
    }
    return (
        getattr(host, accessor_name)
        .through.objects.using(host._state.db)
        .filter(**filter_kwargs)
    )
106
-
107
-
108
def get_feature_set_links(host: Artifact | Collection) -> QuerySet:
    """Return the feature-set link records attached to *host*."""
    filter_kwargs = {get_host_id_field(host): host.id}
    return host.feature_sets.through.objects.filter(**filter_kwargs)
113
-
114
-
115
def get_link_attr(link: LinkORM | type[LinkORM], data: Artifact | Collection) -> str:
    """Derive the label attribute name on a link record from its class name.

    Works for both a link instance and the link class itself.
    """
    name = type(link).__name__
    # when the link class itself was passed, type(link) is its metaclass
    if name in {"Registry", "ModelBase"}:
        name = link.__name__
    return name.replace(type(data).__name__, "").lower()
120
-
121
-
122
# Custom aggregation for SQLite
class GroupConcat(Aggregate):
    """SQLite replacement for Postgres ``ArrayAgg``: concatenates values with ", "."""

    function = "GROUP_CONCAT"
    template = '%(function)s(%(expressions)s, ", ")'
126
-
127
-
128
def custom_aggregate(field, using: str):
    """Pick the array-aggregation expression appropriate for the DB backend."""
    is_postgres = connections[using].vendor == "postgresql"
    return ArrayAgg(field) if is_postgres else GroupConcat(field)
133
-
134
-
135
def print_features(
    self: Artifact | Collection,
    print_types: bool = False,
    to_dict: bool = False,
    print_params: bool = False,
) -> str | dict[str, Any]:
    """Render features (or params) of *self* as a display string or a dict.

    Args:
        print_types: Append the dtype after each feature name.
        to_dict: Return a ``{feature_name: value(s)}`` dict instead of a string.
        print_params: Render params instead of features (skips label/feature-set
            sections).
    """
    from lamindb._from_values import _print_values

    msg = ""
    dictionary = {}
    # categorical feature values
    if not print_params:
        labels_msg = ""
        labels_by_feature = defaultdict(list)
        # collect label names grouped by the feature they annotate
        for _, (_, links) in get_labels_as_dict(self, links=True).items():
            for link in links:
                if link.feature_id is not None:
                    link_attr = get_link_attr(link, self)
                    labels_by_feature[link.feature_id].append(
                        getattr(link, link_attr).name
                    )
        labels_msgs = []
        for feature_id, labels_list in labels_by_feature.items():
            feature = Feature.objects.using(self._state.db).get(id=feature_id)
            print_values = _print_values(labels_list, n=10)
            type_str = f": {feature.dtype}" if print_types else ""
            if to_dict:
                # single label is unwrapped from its list
                dictionary[feature.name] = (
                    labels_list if len(labels_list) > 1 else labels_list[0]
                )
            labels_msgs.append(f" '{feature.name}'{type_str} = {print_values}")
        if len(labels_msgs) > 0:
            labels_msg = "\n".join(sorted(labels_msgs)) + "\n"
            msg += labels_msg

    # non-categorical feature values
    non_labels_msg = ""
    # NOTE(review): due to operator precedence this reads as
    # (self.id is not None and self.__class__ == Artifact) or self.__class__ == Run,
    # i.e. a Run with id None still enters — confirm this is intended
    if self.id is not None and self.__class__ == Artifact or self.__class__ == Run:
        attr_name = "param" if print_params else "feature"
        _feature_values = (
            getattr(self, f"_{attr_name}_values")
            .values(f"{attr_name}__name", f"{attr_name}__dtype")
            .annotate(values=custom_aggregate("value", self._state.db))
            .order_by(f"{attr_name}__name")
        )
        if len(_feature_values) > 0:
            for fv in _feature_values:
                feature_name = fv[f"{attr_name}__name"]
                feature_dtype = fv[f"{attr_name}__dtype"]
                values = fv["values"]
                # TODO: understand why the below is necessary
                if not isinstance(values, list):
                    values = [values]
                if to_dict:
                    dictionary[feature_name] = values if len(values) > 1 else values[0]
                type_str = f": {feature_dtype}" if print_types else ""
                printed_values = (
                    _print_values(values, n=10, quotes=False)
                    if not feature_dtype.startswith("list")
                    else values
                )
                non_labels_msg += f" '{feature_name}'{type_str} = {printed_values}\n"
            msg += non_labels_msg

    if msg != "":
        header = "Features" if not print_params else "Params"
        msg = f" {colors.italic(header)}\n" + msg

    # feature sets
    if not print_params:
        feature_set_msg = ""
        for slot, feature_set in get_feature_set_by_slot_(self).items():
            features = feature_set.members
            # features.first() is a lot slower than features[0] here
            name_field = get_name_field(features[0])
            feature_names = list(features.values_list(name_field, flat=True)[:20])
            type_str = f": {feature_set.registry}" if print_types else ""
            feature_set_msg += (
                f" '{slot}'{type_str} = {_print_values(feature_names)}\n"
            )
        if feature_set_msg:
            msg += f" {colors.italic('Feature sets')}\n"
            msg += feature_set_msg
    if to_dict:
        return dictionary
    else:
        return msg
222
-
223
-
224
def parse_feature_sets_from_anndata(
    adata: AnnData,
    var_field: FieldAttr | None = None,
    obs_field: FieldAttr = Feature.name,
    mute: bool = False,
    organism: str | Record | None = None,
) -> dict:
    """Parse ``var`` and ``obs`` feature sets from an AnnData object or path.

    Returns a dict with keys ``"var"`` and/or ``"obs"`` for the feature sets
    that could be created; slots whose features fail validation are skipped
    with a warning.
    """
    data_parse = adata
    if not isinstance(adata, AnnData):  # is a path
        filepath = create_path(adata)  # returns Path for local
        if not isinstance(filepath, LocalPathClasses):
            # remote path: open through the backed-access layer
            from lamindb.core.storage._backed_access import backed_access

            using_key = settings._using_key
            data_parse = backed_access(filepath, using_key=using_key)
        else:
            # local path: lazily open in backed read mode
            data_parse = ad.read_h5ad(filepath, backed="r")
        # X dtype is not inspected for paths — assume float
        type = "float"
    else:
        type = (
            "float"
            if adata.X is None
            else convert_numpy_dtype_to_lamin_feature_type(adata.X.dtype)
        )
    feature_sets = {}
    if var_field is not None:
        logger.info("parsing feature names of X stored in slot 'var'")
        logger.indent = " "
        feature_set_var = FeatureSet.from_values(
            data_parse.var.index,
            var_field,
            type=type,
            mute=mute,
            organism=organism,
            raise_validation_error=False,
        )
        if feature_set_var is not None:
            feature_sets["var"] = feature_set_var
            logger.save(f"linked: {feature_set_var}")
        logger.indent = ""
        if feature_set_var is None:
            logger.warning("skip linking features to artifact in slot 'var'")
    if len(data_parse.obs.columns) > 0:
        logger.info("parsing feature names of slot 'obs'")
        logger.indent = " "
        feature_set_obs = FeatureSet.from_df(
            df=data_parse.obs,
            field=obs_field,
            mute=mute,
            organism=organism,
        )
        if feature_set_obs is not None:
            feature_sets["obs"] = feature_set_obs
            logger.save(f"linked: {feature_set_obs}")
        logger.indent = ""
        if feature_set_obs is None:
            logger.warning("skip linking features to artifact in slot 'obs'")
    return feature_sets
282
-
283
-
284
def infer_feature_type_convert_json(
    value: Any, mute: bool = False, str_as_ulabel: bool = True
) -> tuple[str, Any]:
    """Infer the lamin feature dtype of *value* and return ``(dtype, value)``.

    Args:
        value: The value to type.
        mute: Suppress the warning emitted when no dtype can be inferred.
        str_as_ulabel: Treat plain strings (and lists of strings) as ULabel
            categoricals rather than raw strings.

    Returns:
        A ``(dtype_string, possibly_converted_value)`` tuple; dtype is ``"?"``
        when nothing could be inferred.
    """
    if isinstance(value, bool):
        return FEATURE_TYPES["bool"], value
    elif isinstance(value, int):
        return FEATURE_TYPES["int"], value
    elif isinstance(value, float):
        return FEATURE_TYPES["float"], value
    elif isinstance(value, str):
        if str_as_ulabel:
            return FEATURE_TYPES["str"] + "[ULabel]", value
        else:
            return "str", value
    elif isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
        if isinstance(value, (pd.Series, np.ndarray)):
            # numpy/pandas containers: derive dtype from the array dtype and
            # convert to a JSON-storable plain list
            return convert_numpy_dtype_to_lamin_feature_type(
                value.dtype, str_as_cat=str_as_ulabel
            ), list(value)
        if isinstance(value, dict):
            return "dict", value
        if len(value) > 0:  # type: ignore
            # only homogeneous iterables get a list dtype
            first_element_type = type(next(iter(value)))
            if all(isinstance(elem, first_element_type) for elem in value):
                if first_element_type is bool:
                    return f"list[{FEATURE_TYPES['bool']}]", value
                elif first_element_type is int:
                    return f"list[{FEATURE_TYPES['int']}]", value
                elif first_element_type is float:
                    return f"list[{FEATURE_TYPES['float']}]", value
                elif first_element_type is str:
                    if str_as_ulabel:
                        # NOTE(review): a *list* of strings returns the scalar
                        # categorical dtype here, not a list[...] dtype — confirm
                        return FEATURE_TYPES["str"] + "[ULabel]", value
                    else:
                        return "list[str]", value
                elif first_element_type == Record:
                    return (
                        f"cat[{first_element_type.__get_name_with_schema__()}]",
                        value,
                    )
    elif isinstance(value, Record):
        return (f"cat[{value.__class__.__get_name_with_schema__()}]", value)
    if not mute:
        # fix: the original message had an unterminated quote ("returning '?")
        logger.warning(f"cannot infer feature type of: {value}, returning '?'")
    return ("?", value)
329
-
330
-
331
def __init__(self, host: Artifact | Collection | Run):
    """Attach the manager to *host* and reset its lazy caches."""
    self._host = host
    # caches filled lazily by the corresponding properties
    self._accessor_by_registry_ = None
    self._feature_set_by_slot_ = None
335
-
336
-
337
def __repr__(self) -> str:
    """Render the host's features (or params when self is a ParamManager)."""
    return print_features(self._host, print_params=(self.__class__ == ParamManager))  # type: ignore
339
-
340
-
341
def get_values(self) -> dict[str, Any]:
    """Get feature values as a dictionary."""
    as_params = self.__class__ == ParamManager
    return print_features(self._host, to_dict=True, print_params=as_params)  # type: ignore
346
-
347
-
348
def __getitem__(self, slot) -> QuerySet:
    """Return the member records of the feature set linked under *slot*."""
    by_slot = self._feature_set_by_slot
    if slot not in by_slot:
        raise ValueError(
            f"No linked feature set for slot: {slot}\nDid you get validation"
            " warnings? Only features that match registered features get validated"
            " and linked."
        )
    feature_set = by_slot[slot]
    accessor = self._accessor_by_registry[feature_set.registry]
    return getattr(feature_set, accessor).all()
358
-
359
-
360
def filter_base(cls, **expression):
    """Translate a feature/param filter expression into a registry query.

    Keys of *expression* are feature (or param) names, optionally suffixed
    with a Django-style comparator (``name__contains=...``); values are the
    filtered-for values. Raises ValidationError for unregistered keys.
    """
    if cls is FeatureManager:
        model = Feature
        value_model = FeatureValue
    else:
        model = Param
        value_model = ParamValue
    keys_normalized = [key.split("__")[0] for key in expression]
    validated = model.validate(keys_normalized, field="name", mute=True)
    if sum(validated) != len(keys_normalized):
        raise ValidationError(
            f"Some keys in the filter expression are not registered as features: {np.array(keys_normalized)[~validated]}"
        )
    new_expression = {}
    features = model.filter(name__in=keys_normalized).all().distinct()
    for key, value in expression.items():
        split_key = key.split("__")
        normalized_key = split_key[0]
        comparator = ""
        if len(split_key) == 2:
            comparator = f"__{split_key[1]}"
        feature = features.get(name=normalized_key)
        if not feature.dtype.startswith("cat"):
            # non-categorical: match against stored feature values
            # NOTE(review): `expression` (the kwargs dict) is rebound here —
            # intentional but shadows the function parameter
            expression = {"feature": feature, f"value{comparator}": value}
            feature_value = value_model.filter(**expression)
            new_expression["_feature_values__in"] = feature_value
        else:
            # categorical: only string lookups against ULabel are supported
            if isinstance(value, str):
                expression = {f"name{comparator}": value}
                label = ULabel.get(**expression)
                new_expression["ulabels"] = label
            else:
                raise NotImplementedError
    if cls == FeatureManager or cls == ParamManagerArtifact:
        return Artifact.filter(**new_expression)
    # might re-enable something similar in the future
    # elif cls == FeatureManagerCollection:
    #     return Collection.filter(**new_expression)
    elif cls == ParamManagerRun:
        return Run.filter(**new_expression)
400
-
401
-
402
@classmethod  # type: ignore
def filter(cls, **expression) -> QuerySet:
    """Query artifacts by features.

    Args:
        expression: Feature-name keys (with optional ``__comparator`` suffix)
            mapped to values to filter for.
    """
    return filter_base(cls, **expression)
406
-
407
-
408
@classmethod  # type: ignore
def get(cls, **expression) -> Record:
    """Query a single artifact by feature.

    Raises an error unless exactly one record matches (via ``.one()``).
    """
    return filter_base(cls, **expression).one()
412
-
413
-
414
@property  # type: ignore
def _feature_set_by_slot(self):
    """Feature sets by slot (lazily computed and cached on first access)."""
    if self._feature_set_by_slot_ is None:
        self._feature_set_by_slot_ = get_feature_set_by_slot_(self._host)
    return self._feature_set_by_slot_
420
-
421
-
422
@property  # type: ignore
def _accessor_by_registry(self):
    """Accessor by ORM (lazily computed and cached on first access)."""
    if self._accessor_by_registry_ is None:
        self._accessor_by_registry_ = get_accessor_by_registry_(self._host)
    return self._accessor_by_registry_
428
-
429
-
430
def _add_values(
    self,
    values: dict[str, str | int | float | bool],
    feature_param_field: FieldAttr,
    str_as_ulabel: bool = True,
) -> None:
    """Curate artifact with features & values.

    Args:
        values: A dictionary of keys (features) & values (labels, numbers, booleans).
        feature_param_field: The field of a reference registry to map keys of the
            dictionary.
        str_as_ulabel: Treat string values as ULabel categoricals.

    Raises:
        ValidationError: If a key is not registered, a value fails validation,
            or the host artifact's type does not match (params need model-like,
            features need dataset-like artifacts).
        TypeError: If a value's inferred type conflicts with the feature dtype.
    """
    # rename to distinguish from the values inside the dict
    features_values = values
    keys = features_values.keys()
    if isinstance(keys, DICT_KEYS_TYPE):
        keys = list(keys)  # type: ignore
    # deal with other cases later
    assert all(isinstance(key, str) for key in keys)  # noqa: S101
    registry = feature_param_field.field.model
    is_param = registry == Param
    model = Param if is_param else Feature
    value_model = ParamValue if is_param else FeatureValue
    model_name = "Param" if is_param else "Feature"
    # params may only annotate model-like artifacts, features dataset-like ones
    if is_param:
        if self._host.__class__ == Artifact:
            if self._host.type != "model":
                raise ValidationError("Can only set params for model-like artifacts.")
    else:
        if self._host.__class__ == Artifact:
            if self._host.type != "dataset" and self._host.type is not None:
                raise ValidationError(
                    "Can only set features for dataset-like artifacts."
                )
    validated = registry.validate(keys, field=feature_param_field, mute=True)
    keys_array = np.array(keys)
    validated_keys = keys_array[validated]
    if validated.sum() != len(keys):
        # build a copy-pasteable hint for creating the missing features/params
        not_validated_keys = keys_array[~validated]
        hint = "\n".join(
            [
                f" ln.{model_name}(name='{key}', dtype='{infer_feature_type_convert_json(features_values[key], str_as_ulabel=str_as_ulabel)[0]}').save()"
                for key in not_validated_keys
            ]
        )
        msg = (
            f"These keys could not be validated: {not_validated_keys.tolist()}\n"
            f"Here is how to create a {model_name.lower()}:\n\n{hint}"
        )
        raise ValidationError(msg)
    registry.from_values(
        validated_keys,
        field=feature_param_field,
    )
    # figure out which of the values go where
    features_labels = defaultdict(list)
    _feature_values = []
    not_validated_values = []
    for key, value in features_values.items():
        feature = model.get(name=key)
        inferred_type, converted_value = infer_feature_type_convert_json(
            value,
            mute=True,
            str_as_ulabel=str_as_ulabel,
        )
        # dtype compatibility checks between the registered feature and value
        if feature.dtype == "number":
            if inferred_type not in {"int", "float"}:
                raise TypeError(
                    f"Value for feature '{key}' with type {feature.dtype} must be a number"
                )
        elif feature.dtype.startswith("cat"):
            if inferred_type != "?":
                if not (inferred_type.startswith("cat") or isinstance(value, Record)):
                    raise TypeError(
                        f"Value for feature '{key}' with type '{feature.dtype}' must be a string or record."
                    )
        elif not inferred_type == feature.dtype:
            raise ValidationError(
                f"Expected dtype for '{key}' is '{feature.dtype}', got '{inferred_type}'"
            )
        if not feature.dtype.startswith("cat"):
            # can remove the query once we have the unique constraint
            filter_kwargs = {model_name.lower(): feature, "value": converted_value}
            feature_value = value_model.filter(**filter_kwargs).one_or_none()
            if feature_value is None:
                feature_value = value_model(**filter_kwargs)
            _feature_values.append(feature_value)
        else:
            # categorical: value is a record, iterable of records, or string(s)
            if isinstance(value, Record) or (
                isinstance(value, Iterable) and isinstance(next(iter(value)), Record)
            ):
                if isinstance(value, Record):
                    label_records = [value]
                else:
                    label_records = value  # type: ignore
                for record in label_records:
                    if record._state.adding:
                        raise ValidationError(
                            f"Please save {record} before annotation."
                        )
                    features_labels[record.__class__.__get_name_with_schema__()].append(
                        (feature, record)
                    )
            else:
                if isinstance(value, str):
                    values = [value]  # type: ignore
                else:
                    values = value  # type: ignore
                # widen the feature dtype to a ULabel categorical if needed
                if "ULabel" not in feature.dtype:
                    feature.dtype += "[ULabel]"
                    feature.save()
                validated = ULabel.validate(values, field="name", mute=True)
                values_array = np.array(values)
                validated_values = values_array[validated]
                if validated.sum() != len(values):
                    not_validated_values += values_array[~validated].tolist()
                label_records = ULabel.from_values(validated_values, field="name")
                features_labels["ULabel"] += [
                    (feature, label_record) for label_record in label_records
                ]
    if not_validated_values:
        hint = (
            f" ulabels = ln.ULabel.from_values({not_validated_values}, create=True)\n"
            f" ln.save(ulabels)"
        )
        msg = (
            f"These values could not be validated: {not_validated_values}\n"
            f"Here is how to create ulabels for them:\n\n{hint}"
        )
        raise ValidationError(msg)
    # bulk add all links to ArtifactULabel
    if features_labels:
        if list(features_labels.keys()) != ["ULabel"]:
            related_names = dict_related_model_to_related_name(self._host.__class__)
        else:
            related_names = {"ULabel": "ulabels"}
        for class_name, registry_features_labels in features_labels.items():
            related_name = related_names[class_name]  # e.g., "ulabels"
            LinkORM = getattr(self._host, related_name).through
            field_name = f"{get_link_attr(LinkORM, self._host)}_id"  # e.g., ulabel_id
            links = [
                LinkORM(
                    **{
                        "artifact_id": self._host.id,
                        "feature_id": feature.id,
                        field_name: label.id,
                    }
                )
                for (feature, label) in registry_features_labels
            ]
            # a link might already exist
            try:
                save(links, ignore_conflicts=False)
            except Exception:
                save(links, ignore_conflicts=True)
            # now deal with links that were previously saved without a feature_id
            links_saved = LinkORM.filter(
                **{
                    "artifact_id": self._host.id,
                    f"{field_name}__in": [
                        l.id for _, l in registry_features_labels
                    ],
                }
            )
            for link in links_saved.all():
                # TODO: also check for inconsistent features
                if link.feature_id is None:
                    link.feature_id = [
                        f.id
                        for f, l in registry_features_labels
                        if l.id == getattr(link, field_name)
                    ][0]
                    link.save()
    if _feature_values:
        save(_feature_values)
        if is_param:
            LinkORM = self._host._param_values.through
            valuefield_id = "paramvalue_id"
        else:
            LinkORM = self._host._feature_values.through
            valuefield_id = "featurevalue_id"
        links = [
            LinkORM(
                **{
                    f"{self._host.__class__.__get_name_with_schema__().lower()}_id": self._host.id,
                    valuefield_id: feature_value.id,
                }
            )
            for feature_value in _feature_values
        ]
        # a link might already exist, to avoid raising a unique constraint
        # error, ignore_conflicts
        save(links, ignore_conflicts=True)
624
-
625
-
626
def add_values_features(
    self,
    values: dict[str, str | int | float | bool],
    feature_field: FieldAttr = Feature.name,
    str_as_ulabel: bool = True,
) -> None:
    """Annotate the host with feature values.

    Args:
        values: A dictionary of keys (features) & values (labels, numbers, booleans).
        feature_field: The field of a reference registry to map keys of the
            dictionary.
        str_as_ulabel: Whether to interpret string values as ulabels.
    """
    # delegate to the shared feature/param implementation
    _add_values(self, values, feature_field, str_as_ulabel=str_as_ulabel)
641
-
642
-
643
def add_values_params(
    self,
    values: dict[str, str | int | float | bool],
) -> None:
    """Annotate the host with param values.

    Args:
        values: A dictionary of keys (params) & values (labels, numbers, booleans).
    """
    # params are never interpreted as ulabels
    _add_values(self, values, Param.name, str_as_ulabel=False)
653
-
654
-
655
def add_feature_set(self, feature_set: FeatureSet, slot: str) -> None:
    """Curate artifact with a feature set.

    Args:
        feature_set: `FeatureSet` A feature set record.
        slot: `str` The slot that marks where the feature set is stored in
            the artifact.

    Raises:
        ValueError: If the host has not been saved yet.
    """
    if self._host._state.adding:
        raise ValueError(
            "Please save the artifact or collection before adding a feature set!"
        )
    host_db = self._host._state.db
    # persist the feature set on the same database as the host
    feature_set.save(using=host_db)
    host_id_field = get_host_id_field(self._host)
    kwargs = {
        host_id_field: self._host.id,
        "featureset": feature_set,
        "slot": slot,
    }
    link_record = (
        self._host.feature_sets.through.objects.using(host_db)
        .filter(**kwargs)
        .one_or_none()
    )
    if link_record is None:
        self._host.feature_sets.through(**kwargs).save(using=host_db)
        if slot in self._feature_set_by_slot:
            logger.debug(f"replaced existing {slot} feature set")
    # keep the slot cache in sync
    self._feature_set_by_slot_[slot] = feature_set  # type: ignore
685
-
686
-
687
def _add_set_from_df(
    self, field: FieldAttr = Feature.name, organism: str | None = None
):
    """Add feature set corresponding to column names of DataFrame.

    Args:
        field: The registry field to validate column names against.
        organism: Organism passed through to feature validation.
    """
    if isinstance(self._host, Artifact):
        assert self._host._accessor == "DataFrame"  # noqa: S101
    else:
        # Collection
        assert self._host.artifact._accessor == "DataFrame"  # noqa: S101

    # parse and register features
    registry = field.field.model
    df = self._host.load()
    features = registry.from_values(df.columns, field=field, organism=organism)
    if len(features) == 0:
        logger.error(
            "no validated features found in DataFrame! please register features first!"
        )
        return

    # create and link feature sets
    feature_set = FeatureSet(features=features)
    feature_sets = {"columns": feature_set}
    self._host._feature_sets = feature_sets
    self._host.save()
712
-
713
-
714
def _add_set_from_anndata(
    self,
    var_field: FieldAttr,
    obs_field: FieldAttr | None = Feature.name,
    mute: bool = False,
    organism: str | Record | None = None,
):
    """Add features from AnnData.

    Parses ``var`` and ``obs`` into feature sets and links them to the host
    artifact; only Artifact hosts are supported.
    """
    if isinstance(self._host, Artifact):
        assert self._host._accessor == "AnnData"  # noqa: S101
    else:
        raise NotImplementedError()

    # parse and register features
    adata = self._host.load()
    feature_sets = parse_feature_sets_from_anndata(
        adata,
        var_field=var_field,
        obs_field=obs_field,
        mute=mute,
        organism=organism,
    )

    # link feature sets
    self._host._feature_sets = feature_sets
    self._host.save()
740
-
741
-
742
def _add_set_from_mudata(
    self,
    var_fields: dict[str, FieldAttr],
    obs_fields: dict[str, FieldAttr] | None = None,
    mute: bool = False,
    organism: str | Record | None = None,
):
    """Add features from MuData.

    Parses the global ``obs`` plus each modality's ``var``/``obs`` into
    feature sets, de-duplicates them by hash, and links them to the host
    artifact; only Artifact hosts are supported.
    """
    if obs_fields is None:
        obs_fields = {}
    if isinstance(self._host, Artifact):
        assert self._host._accessor == "MuData"  # noqa: S101
    else:
        raise NotImplementedError()

    # parse and register features
    mdata = self._host.load()
    feature_sets = {}
    obs_features = Feature.from_values(mdata.obs.columns)
    if len(obs_features) > 0:
        feature_sets["obs"] = FeatureSet(features=obs_features)
    for modality, field in var_fields.items():
        modality_fs = parse_feature_sets_from_anndata(
            mdata[modality],
            var_field=field,
            obs_field=obs_fields.get(modality, Feature.name),
            mute=mute,
            organism=organism,
        )
        for k, v in modality_fs.items():
            # slot keys are namespaced by modality, e.g. "['rna'].var"
            feature_sets[f"['{modality}'].{k}"] = v

    def unify_feature_sets_by_hash(feature_sets):
        # replace duplicate feature sets (same hash) with a single shared object
        unique_values = {}

        for key, value in feature_sets.items():
            value_hash = value.hash  # Assuming each value has a .hash attribute
            if value_hash in unique_values:
                feature_sets[key] = unique_values[value_hash]
            else:
                unique_values[value_hash] = value

        return feature_sets

    # link feature sets
    self._host._feature_sets = unify_feature_sets_by_hash(feature_sets)
    self._host.save()
789
-
790
-
791
def _add_from(self, data: Artifact | Collection, transfer_logs: dict | None = None):
    """Transfer features from a artifact or collection.

    Copies each of *data*'s feature sets (slot by slot) into the default
    database, creating any missing member records first.
    """
    # This only covers feature sets
    if transfer_logs is None:
        transfer_logs = {"mapped": [], "transferred": []}
    using_key = settings._using_key
    for slot, feature_set in data.features._feature_set_by_slot.items():
        members = feature_set.members
        if len(members) == 0:
            continue
        registry = members[0].__class__
        # note here the features are transferred based on an unique field
        field = REGISTRY_UNIQUE_FIELD.get(registry.__name__.lower(), "uid")
        if hasattr(registry, "_ontology_id_field"):
            field = registry._ontology_id_field
        # this will be e.g. be a list of ontology_ids or uids
        member_uids = list(members.values_list(field, flat=True))
        # create records from ontology_id
        if hasattr(registry, "_ontology_id_field") and len(member_uids) > 0:
            # create from bionty
            members_records = registry.from_values(member_uids, field=field)
            save([r for r in members_records if r._state.adding])
        validated = registry.validate(member_uids, field=field, mute=True)
        new_members_uids = list(compress(member_uids, ~validated))
        new_members = members.filter(**{f"{field}__in": new_members_uids}).all()
        n_new_members = len(new_members)
        if n_new_members > 0:
            # transfer foreign keys needs to be run before transfer to default db
            transfer_fk_to_default_db_bulk(
                new_members, using_key, transfer_logs=transfer_logs
            )
            for feature in new_members:
                # not calling save=True here as in labels, because want to
                # bulk save below
                # transfer_fk is set to False because they are already transferred
                # in the previous step transfer_fk_to_default_db_bulk
                transfer_to_default_db(
                    feature, using_key, transfer_fk=False, transfer_logs=transfer_logs
                )
            logger.info(f"saving {n_new_members} new {registry.__name__} records")
            save(new_members)

        # create a new feature set from feature values using the same uid
        feature_set_self = FeatureSet.from_values(
            member_uids, field=getattr(registry, field)
        )
        if feature_set_self is None:
            if hasattr(registry, "organism_id"):
                logger.warning(
                    f"FeatureSet is not transferred, check if organism is set correctly: {feature_set}"
                )
            continue
        # make sure the uid matches if featureset is composed of same features
        if feature_set_self.hash == feature_set.hash:
            feature_set_self.uid = feature_set.uid
        logger.info(f"saving {slot} featureset: {feature_set_self}")
        self._host.features.add_feature_set(feature_set_self, slot)
848
-
849
-
850
- FeatureManager.__init__ = __init__
851
- ParamManager.__init__ = __init__
852
- FeatureManager.__repr__ = __repr__
853
- ParamManager.__repr__ = __repr__
854
- FeatureManager.__getitem__ = __getitem__
855
- FeatureManager.get_values = get_values
856
- FeatureManager._feature_set_by_slot = _feature_set_by_slot
857
- FeatureManager._accessor_by_registry = _accessor_by_registry
858
- FeatureManager.add_values = add_values_features
859
- FeatureManager.add_feature_set = add_feature_set
860
- FeatureManager._add_set_from_df = _add_set_from_df
861
- FeatureManager._add_set_from_anndata = _add_set_from_anndata
862
- FeatureManager._add_set_from_mudata = _add_set_from_mudata
863
- FeatureManager._add_from = _add_from
864
- FeatureManager.filter = filter
865
- FeatureManager.get = get
866
- ParamManager.add_values = add_values_params
867
- ParamManager.get_values = get_values
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict
4
+ from collections.abc import Iterable
5
+ from itertools import compress
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ import anndata as ad
9
+ import numpy as np
10
+ import pandas as pd
11
+ from anndata import AnnData
12
+ from django.contrib.postgres.aggregates import ArrayAgg
13
+ from django.db import connections
14
+ from django.db.models import Aggregate
15
+ from lamin_utils import colors, logger
16
+ from lamindb_setup.core.upath import create_path
17
+ from lnschema_core.models import (
18
+ Artifact,
19
+ Collection,
20
+ Feature,
21
+ FeatureManager,
22
+ FeatureValue,
23
+ LinkORM,
24
+ Param,
25
+ ParamManager,
26
+ ParamManagerArtifact,
27
+ ParamManagerRun,
28
+ ParamValue,
29
+ Record,
30
+ Run,
31
+ ULabel,
32
+ )
33
+
34
+ from lamindb._feature import FEATURE_TYPES, convert_numpy_dtype_to_lamin_feature_type
35
+ from lamindb._feature_set import DICT_KEYS_TYPE, FeatureSet
36
+ from lamindb._record import (
37
+ REGISTRY_UNIQUE_FIELD,
38
+ get_name_field,
39
+ transfer_fk_to_default_db_bulk,
40
+ transfer_to_default_db,
41
+ )
42
+ from lamindb._save import save
43
+ from lamindb.core.exceptions import ValidationError
44
+ from lamindb.core.storage import LocalPathClasses
45
+
46
+ from ._label_manager import get_labels_as_dict
47
+ from ._settings import settings
48
+ from .schema import (
49
+ dict_related_model_to_related_name,
50
+ )
51
+
52
+ if TYPE_CHECKING:
53
+ from lnschema_core.types import FieldAttr
54
+
55
+ from lamindb._query_set import QuerySet
56
+
57
+
58
def get_host_id_field(host: Artifact | Collection) -> str:
    """Return the foreign-key column name that points at ``host`` in link tables."""
    # Artifacts link via "artifact_id"; any other host is treated as a collection.
    return "artifact_id" if isinstance(host, Artifact) else "collection_id"
64
+
65
+
66
def get_accessor_by_registry_(host: Artifact | Collection) -> dict:
    """Map registry names (with schema prefix) to their accessor names on ``host``."""
    mapping: dict = {}
    for related in host._meta.related_objects:
        mapping[related.related_model.__get_name_with_schema__()] = related.name
    # Feature and ULabel accessors always exist under these fixed names.
    mapping["Feature"] = "features"
    mapping["ULabel"] = "ulabels"
    return mapping
74
+
75
+
76
def get_feature_set_by_slot_(host) -> dict:
    """Return a ``{slot: FeatureSet}`` mapping for ``host``.

    For an unsaved host, reads the in-memory ``_feature_sets`` staging dict
    (if present); for a saved host, queries the link table on the host's
    database.
    """
    # if the host is not yet saved
    if host._state.adding:
        if hasattr(host, "_feature_sets"):
            return host._feature_sets
        else:
            return {}
    host_db = host._state.db
    host_id_field = get_host_id_field(host)
    kwargs = {host_id_field: host.id}
    # otherwise, we need a query; select_related avoids one query per link row
    links_feature_set = (
        host.feature_sets.through.objects.using(host_db)
        .filter(**kwargs)
        .select_related("featureset")
    )
    return {fsl.slot: fsl.featureset for fsl in links_feature_set}
93
+
94
+
95
def get_label_links(
    host: Artifact | Collection, registry: str, feature: Feature
) -> QuerySet:
    """Return link records for labels of ``registry`` annotated via ``feature``.

    Resolves the label accessor (e.g. ``ulabels``) through the host's feature
    manager and filters the link table by host id and feature id on the host's
    database.
    """
    host_id_field = get_host_id_field(host)
    kwargs = {host_id_field: host.id, "feature_id": feature.id}
    link_records = (
        getattr(host, host.features._accessor_by_registry[registry])
        .through.objects.using(host._state.db)
        .filter(**kwargs)
    )
    return link_records
106
+
107
+
108
def get_feature_set_links(host: Artifact | Collection) -> QuerySet:
    """Return the raw link records connecting ``host`` to its feature sets."""
    lookup = {get_host_id_field(host): host.id}
    return host.feature_sets.through.objects.filter(**lookup)
113
+
114
+
115
def get_link_attr(link: LinkORM | type[LinkORM], data: Artifact | Collection) -> str:
    """Derive the label attribute name from a link record (or link class).

    E.g. an ``ArtifactULabel`` link on an ``Artifact`` host yields ``"ulabel"``.
    """
    name = type(link).__name__
    # A metaclass name here means the link *class* was passed, not an instance.
    if name in {"Registry", "ModelBase"}:
        name = link.__name__
    host_class_name = type(data).__name__
    return name.replace(host_class_name, "").lower()
120
+
121
+
122
# Custom aggregation for SQLite: mimics Postgres' ArrayAgg by concatenating
# group values into a single comma-separated string.
class GroupConcat(Aggregate):
    function = "GROUP_CONCAT"
    # ", " separator so the concatenated value is human-readable
    template = '%(function)s(%(expressions)s, ", ")'
126
+
127
+
128
def custom_aggregate(field, using: str):
    """Pick an array-style aggregate appropriate for the backing database."""
    # Postgres has a native array aggregate; other vendors (SQLite) fall back
    # to string concatenation via GROUP_CONCAT.
    is_postgres = connections[using].vendor == "postgresql"
    return ArrayAgg(field) if is_postgres else GroupConcat(field)
133
+
134
+
135
def print_features(
    self: Artifact | Collection,
    print_types: bool = False,
    to_dict: bool = False,
    print_params: bool = False,
) -> str | dict[str, Any]:
    """Render the host's feature (or param) annotations.

    Args:
        self: The host record whose annotations are rendered.
        print_types: Append each feature's dtype after its name.
        to_dict: Return a ``{feature_name: value(s)}`` dict instead of a string.
        print_params: Render params instead of features (skips label and
            feature-set sections).

    Returns:
        A formatted multi-line string, or a dictionary if ``to_dict=True``.
    """
    from lamindb._from_values import _print_values

    msg = ""
    dictionary = {}
    # categorical feature values: label links that carry an explicit feature_id
    if not print_params:
        labels_msg = ""
        labels_by_feature = defaultdict(list)
        for _, (_, links) in get_labels_as_dict(self, links=True).items():
            for link in links:
                if link.feature_id is not None:
                    link_attr = get_link_attr(link, self)
                    labels_by_feature[link.feature_id].append(
                        getattr(link, link_attr).name
                    )
        labels_msgs = []
        for feature_id, labels_list in labels_by_feature.items():
            # one query per feature, routed to the host's database
            feature = Feature.objects.using(self._state.db).get(id=feature_id)
            print_values = _print_values(labels_list, n=10)
            type_str = f": {feature.dtype}" if print_types else ""
            if to_dict:
                # single labels are unwrapped; multiple labels stay a list
                dictionary[feature.name] = (
                    labels_list if len(labels_list) > 1 else labels_list[0]
                )
            labels_msgs.append(f" '{feature.name}'{type_str} = {print_values}")
        if len(labels_msgs) > 0:
            labels_msg = "\n".join(sorted(labels_msgs)) + "\n"
            msg += labels_msg

    # non-categorical feature values
    # NOTE(review): `and` binds tighter than `or`, so this condition reads as
    # (id is not None AND is Artifact) OR (is Run) — i.e. a Run with id=None
    # also passes. Confirm that this is intended.
    non_labels_msg = ""
    if self.id is not None and self.__class__ == Artifact or self.__class__ == Run:
        attr_name = "param" if print_params else "feature"
        _feature_values = (
            getattr(self, f"_{attr_name}_values")
            .values(f"{attr_name}__name", f"{attr_name}__dtype")
            .annotate(values=custom_aggregate("value", self._state.db))
            .order_by(f"{attr_name}__name")
        )
        if len(_feature_values) > 0:
            for fv in _feature_values:
                feature_name = fv[f"{attr_name}__name"]
                feature_dtype = fv[f"{attr_name}__dtype"]
                values = fv["values"]
                # TODO: understand why the below is necessary
                # (presumably SQLite's GROUP_CONCAT yields a scalar string
                # while Postgres' ArrayAgg yields a list — verify)
                if not isinstance(values, list):
                    values = [values]
                if to_dict:
                    dictionary[feature_name] = values if len(values) > 1 else values[0]
                type_str = f": {feature_dtype}" if print_types else ""
                printed_values = (
                    _print_values(values, n=10, quotes=False)
                    if not feature_dtype.startswith("list")
                    else values
                )
                non_labels_msg += f" '{feature_name}'{type_str} = {printed_values}\n"
        msg += non_labels_msg

    if msg != "":
        header = "Features" if not print_params else "Params"
        msg = f" {colors.italic(header)}\n" + msg

    # feature sets: one line per slot with up to 20 member names
    if not print_params:
        feature_set_msg = ""
        for slot, feature_set in get_feature_set_by_slot_(self).items():
            features = feature_set.members
            # features.first() is a lot slower than features[0] here
            name_field = get_name_field(features[0])
            feature_names = list(features.values_list(name_field, flat=True)[:20])
            type_str = f": {feature_set.registry}" if print_types else ""
            feature_set_msg += (
                f" '{slot}'{type_str} = {_print_values(feature_names)}\n"
            )
        if feature_set_msg:
            msg += f" {colors.italic('Feature sets')}\n"
            msg += feature_set_msg
    if to_dict:
        return dictionary
    else:
        return msg
222
+
223
+
224
def parse_feature_sets_from_anndata(
    adata: AnnData,
    var_field: FieldAttr | None = None,
    obs_field: FieldAttr = Feature.name,
    mute: bool = False,
    organism: str | Record | None = None,
) -> dict:
    """Parse feature sets for the ``var`` and ``obs`` slots of an AnnData.

    Args:
        adata: An ``AnnData`` object or a path to one (local or remote).
        var_field: Registry field to map ``var`` index (X feature names) against;
            skipped if ``None``.
        obs_field: Registry field to map ``obs`` column names against.
        mute: Suppress validation logging.
        organism: Organism passed through to registry lookups.

    Returns:
        A dict with up to two entries, ``"var"`` and ``"obs"``, mapping to
        :class:`FeatureSet` records; slots that fail to parse are omitted.
    """
    data_parse = adata
    if not isinstance(adata, AnnData):  # is a path
        filepath = create_path(adata)  # returns Path for local
        if not isinstance(filepath, LocalPathClasses):
            from lamindb.core.storage._backed_access import backed_access

            using_key = settings._using_key
            data_parse = backed_access(filepath, using_key=using_key)
        else:
            data_parse = ad.read_h5ad(filepath, backed="r")
        # fix: renamed local from `type` (shadowed the builtin) to `feature_type`
        feature_type = "float"
    else:
        feature_type = (
            "float"
            if adata.X is None
            else convert_numpy_dtype_to_lamin_feature_type(adata.X.dtype)
        )
    feature_sets = {}
    if var_field is not None:
        logger.info("parsing feature names of X stored in slot 'var'")
        logger.indent = " "
        feature_set_var = FeatureSet.from_values(
            data_parse.var.index,
            var_field,
            type=feature_type,
            mute=mute,
            organism=organism,
            raise_validation_error=False,
        )
        if feature_set_var is not None:
            feature_sets["var"] = feature_set_var
            logger.save(f"linked: {feature_set_var}")
        logger.indent = ""
        if feature_set_var is None:
            logger.warning("skip linking features to artifact in slot 'var'")
    if len(data_parse.obs.columns) > 0:
        logger.info("parsing feature names of slot 'obs'")
        logger.indent = " "
        feature_set_obs = FeatureSet.from_df(
            df=data_parse.obs,
            field=obs_field,
            mute=mute,
            organism=organism,
        )
        if feature_set_obs is not None:
            feature_sets["obs"] = feature_set_obs
            logger.save(f"linked: {feature_set_obs}")
        logger.indent = ""
        if feature_set_obs is None:
            logger.warning("skip linking features to artifact in slot 'obs'")
    return feature_sets
282
+
283
+
284
def infer_feature_type_convert_json(
    value: Any, mute: bool = False, str_as_ulabel: bool = True
) -> tuple[str, Any]:
    """Infer a lamin feature dtype string for ``value``.

    Args:
        value: A scalar, string, Record, or homogeneous iterable thereof.
        mute: Suppress the warning emitted when no type can be inferred.
        str_as_ulabel: Treat strings as categorical ULabels rather than plain str.

    Returns:
        ``(dtype_string, converted_value)``; dtype is ``"?"`` when inference fails.
    """
    # bool must be tested before int: in Python, bool is a subclass of int
    if isinstance(value, bool):
        return FEATURE_TYPES["bool"], value
    elif isinstance(value, int):
        return FEATURE_TYPES["int"], value
    elif isinstance(value, float):
        return FEATURE_TYPES["float"], value
    elif isinstance(value, str):
        if str_as_ulabel:
            return FEATURE_TYPES["str"] + "[ULabel]", value
        else:
            return "str", value
    elif isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
        # numpy/pandas containers carry a dtype; convert to a plain list
        if isinstance(value, (pd.Series, np.ndarray)):
            return convert_numpy_dtype_to_lamin_feature_type(
                value.dtype, str_as_cat=str_as_ulabel
            ), list(value)
        if isinstance(value, dict):
            return "dict", value
        if len(value) > 0:  # type: ignore
            first_element_type = type(next(iter(value)))
            # only infer a list dtype for homogeneous containers
            if all(isinstance(elem, first_element_type) for elem in value):
                if first_element_type is bool:
                    return f"list[{FEATURE_TYPES['bool']}]", value
                elif first_element_type is int:
                    return f"list[{FEATURE_TYPES['int']}]", value
                elif first_element_type is float:
                    return f"list[{FEATURE_TYPES['float']}]", value
                elif first_element_type is str:
                    # NOTE(review): unlike the bool/int/float branches, this
                    # returns the scalar cat dtype, not list[...] — confirm
                    # this asymmetry is intended
                    if str_as_ulabel:
                        return FEATURE_TYPES["str"] + "[ULabel]", value
                    else:
                        return "list[str]", value
                elif first_element_type == Record:
                    return (
                        f"cat[{first_element_type.__get_name_with_schema__()}]",
                        value,
                    )
    elif isinstance(value, Record):
        return (f"cat[{value.__class__.__get_name_with_schema__()}]", value)
    if not mute:
        # fix: closed the unbalanced quote in the log message ("'?" -> "'?'")
        logger.warning(f"cannot infer feature type of: {value}, returning '?'")
    return ("?", value)
329
+
330
+
331
def __init__(self, host: Artifact | Collection | Run):
    """Bind the manager to its host record; lazy caches start unset."""
    self._host = host
    # Both caches are filled on first access by the corresponding properties.
    self._feature_set_by_slot_ = self._accessor_by_registry_ = None
335
+
336
+
337
def __repr__(self) -> str:
    # Delegate rendering to print_features; ParamManager instances render
    # params instead of features.
    return print_features(self._host, print_params=(self.__class__ == ParamManager))  # type: ignore
339
+
340
+
341
def get_values(self) -> dict[str, Any]:
    """Get feature values as a dictionary."""
    # to_dict=True makes print_features return {feature_name: value(s)}
    # instead of the formatted string used by __repr__
    return print_features(
        self._host, to_dict=True, print_params=(self.__class__ == ParamManager)
    )  # type: ignore
346
+
347
+
348
def __getitem__(self, slot) -> QuerySet:
    """Return the members of the feature set linked under ``slot``.

    Raises:
        ValueError: If no feature set is linked for ``slot``.
    """
    if slot not in self._feature_set_by_slot:
        raise ValueError(
            f"No linked feature set for slot: {slot}\nDid you get validation"
            " warnings? Only features that match registered features get validated"
            " and linked."
        )
    feature_set = self._feature_set_by_slot[slot]
    # resolve the registry-specific accessor (e.g. "features") on the feature set
    orm_name = feature_set.registry
    return getattr(feature_set, self._accessor_by_registry[orm_name]).all()
358
+
359
+
360
def filter_base(cls, **expression):
    """Build a host queryset from a feature/param filter expression.

    Keys look like ``name`` or ``name__lookup``; values are matched against
    stored feature/param values (non-categorical dtypes) or ULabel names
    (categorical dtypes).

    Raises:
        ValidationError: If any key is not a registered feature/param name.
        NotImplementedError: For categorical filters with non-string values.
    """
    # choose the key registry and the value registry based on the manager type
    if cls is FeatureManager:
        model = Feature
        value_model = FeatureValue
    else:
        model = Param
        value_model = ParamValue
    # strip django-style lookup suffixes to get the feature/param names
    keys_normalized = [key.split("__")[0] for key in expression]
    validated = model.validate(keys_normalized, field="name", mute=True)
    if sum(validated) != len(keys_normalized):
        raise ValidationError(
            f"Some keys in the filter expression are not registered as features: {np.array(keys_normalized)[~validated]}"
        )
    new_expression = {}
    features = model.filter(name__in=keys_normalized).all().distinct()
    for key, value in expression.items():
        split_key = key.split("__")
        normalized_key = split_key[0]
        comparator = ""
        if len(split_key) == 2:
            comparator = f"__{split_key[1]}"
        feature = features.get(name=normalized_key)
        if not feature.dtype.startswith("cat"):
            # NOTE(review): rebinding `expression` here shadows the kwargs dict
            # being iterated; harmless since .items() was taken before the loop,
            # but worth renaming for clarity.
            expression = {"feature": feature, f"value{comparator}": value}
            feature_value = value_model.filter(**expression)
            new_expression["_feature_values__in"] = feature_value
        else:
            if isinstance(value, str):
                # categorical values are matched by ULabel name
                expression = {f"name{comparator}": value}
                label = ULabel.get(**expression)
                new_expression["ulabels"] = label
            else:
                raise NotImplementedError
    if cls == FeatureManager or cls == ParamManagerArtifact:
        return Artifact.filter(**new_expression)
    # might renable something similar in the future
    # elif cls == FeatureManagerCollection:
    #     return Collection.filter(**new_expression)
    elif cls == ParamManagerRun:
        return Run.filter(**new_expression)
400
+
401
+
402
@classmethod  # type: ignore
def filter(cls, **expression) -> QuerySet:
    """Query artifacts by features."""
    # shared implementation also used by ParamManager variants
    return filter_base(cls, **expression)
406
+
407
+
408
@classmethod  # type: ignore
def get(cls, **expression) -> Record:
    """Query a single artifact by feature."""
    # .one() raises if zero or multiple records match
    return filter_base(cls, **expression).one()
412
+
413
+
414
@property  # type: ignore
def _feature_set_by_slot(self):
    """Feature sets by slot."""
    # lazily compute and memoize the slot -> FeatureSet mapping
    if self._feature_set_by_slot_ is None:
        self._feature_set_by_slot_ = get_feature_set_by_slot_(self._host)
    return self._feature_set_by_slot_
420
+
421
+
422
@property  # type: ignore
def _accessor_by_registry(self):
    """Accessor by ORM."""
    # lazily compute and memoize the registry-name -> accessor mapping
    if self._accessor_by_registry_ is None:
        self._accessor_by_registry_ = get_accessor_by_registry_(self._host)
    return self._accessor_by_registry_
428
+
429
+
430
def _add_values(
    self,
    values: dict[str, str | int | float | bool],
    feature_param_field: FieldAttr,
    str_as_ulabel: bool = True,
) -> None:
    """Curate artifact with features & values.

    Validates all keys against the Feature/Param registry, infers dtypes for
    the values, then bulk-creates value records and link rows.

    Args:
        values: A dictionary of keys (features) & values (labels, numbers, booleans).
        feature_param_field: The field of a reference registry to map keys of the
            dictionary.
        str_as_ulabel: Whether string values are interpreted as ULabel names.

    Raises:
        ValidationError: For unregistered keys, unvalidated label values,
            dtype mismatches, or unsaved label records.
        TypeError: When a value's inferred type conflicts with the feature dtype.
    """
    # rename to distinguish from the values inside the dict
    features_values = values
    keys = features_values.keys()
    if isinstance(keys, DICT_KEYS_TYPE):
        keys = list(keys)  # type: ignore
    # deal with other cases later
    assert all(isinstance(key, str) for key in keys)  # noqa: S101
    registry = feature_param_field.field.model
    is_param = registry == Param
    model = Param if is_param else Feature
    value_model = ParamValue if is_param else FeatureValue
    model_name = "Param" if is_param else "Feature"
    # params only apply to model-like artifacts; features to dataset-like ones
    if is_param:
        if self._host.__class__ == Artifact:
            if self._host.type != "model":
                raise ValidationError("Can only set params for model-like artifacts.")
    else:
        if self._host.__class__ == Artifact:
            if self._host.type != "dataset" and self._host.type is not None:
                raise ValidationError(
                    "Can only set features for dataset-like artifacts."
                )
    validated = registry.validate(keys, field=feature_param_field, mute=True)
    keys_array = np.array(keys)
    validated_keys = keys_array[validated]
    if validated.sum() != len(keys):
        # build a copy-pastable hint showing how to register the missing keys
        not_validated_keys = keys_array[~validated]
        hint = "\n".join(
            [
                f" ln.{model_name}(name='{key}', dtype='{infer_feature_type_convert_json(features_values[key], str_as_ulabel=str_as_ulabel)[0]}').save()"
                for key in not_validated_keys
            ]
        )
        msg = (
            f"These keys could not be validated: {not_validated_keys.tolist()}\n"
            f"Here is how to create a {model_name.lower()}:\n\n{hint}"
        )
        raise ValidationError(msg)
    registry.from_values(
        validated_keys,
        field=feature_param_field,
    )
    # figure out which of the values go where
    features_labels = defaultdict(list)
    _feature_values = []
    not_validated_values = []
    for key, value in features_values.items():
        feature = model.get(name=key)
        inferred_type, converted_value = infer_feature_type_convert_json(
            value,
            mute=True,
            str_as_ulabel=str_as_ulabel,
        )
        # dtype compatibility checks between declared and inferred type
        if feature.dtype == "number":
            if inferred_type not in {"int", "float"}:
                raise TypeError(
                    f"Value for feature '{key}' with type {feature.dtype} must be a number"
                )
        elif feature.dtype.startswith("cat"):
            # "?" means inference failed; skip the check in that case
            if inferred_type != "?":
                if not (inferred_type.startswith("cat") or isinstance(value, Record)):
                    raise TypeError(
                        f"Value for feature '{key}' with type '{feature.dtype}' must be a string or record."
                    )
        elif not inferred_type == feature.dtype:
            raise ValidationError(
                f"Expected dtype for '{key}' is '{feature.dtype}', got '{inferred_type}'"
            )
        if not feature.dtype.startswith("cat"):
            # non-categorical: store as a (feature, value) record, deduplicated
            # can remove the query once we have the unique constraint
            filter_kwargs = {model_name.lower(): feature, "value": converted_value}
            feature_value = value_model.filter(**filter_kwargs).one_or_none()
            if feature_value is None:
                feature_value = value_model(**filter_kwargs)
            _feature_values.append(feature_value)
        else:
            # categorical: collect (feature, label-record) pairs per registry
            if isinstance(value, Record) or (
                isinstance(value, Iterable) and isinstance(next(iter(value)), Record)
            ):
                if isinstance(value, Record):
                    label_records = [value]
                else:
                    label_records = value  # type: ignore
                for record in label_records:
                    if record._state.adding:
                        raise ValidationError(
                            f"Please save {record} before annotation."
                        )
                    features_labels[record.__class__.__get_name_with_schema__()].append(
                        (feature, record)
                    )
            else:
                # plain strings become ULabels; widen the feature dtype if needed
                if isinstance(value, str):
                    values = [value]  # type: ignore
                else:
                    values = value  # type: ignore
                if "ULabel" not in feature.dtype:
                    feature.dtype += "[ULabel]"
                    feature.save()
                validated = ULabel.validate(values, field="name", mute=True)
                values_array = np.array(values)
                validated_values = values_array[validated]
                if validated.sum() != len(values):
                    not_validated_values += values_array[~validated].tolist()
                label_records = ULabel.from_values(validated_values, field="name")
                features_labels["ULabel"] += [
                    (feature, label_record) for label_record in label_records
                ]
    if not_validated_values:
        hint = (
            f" ulabels = ln.ULabel.from_values({not_validated_values}, create=True)\n"
            f" ln.save(ulabels)"
        )
        msg = (
            f"These values could not be validated: {not_validated_values}\n"
            f"Here is how to create ulabels for them:\n\n{hint}"
        )
        raise ValidationError(msg)
    # bulk add all links to ArtifactULabel
    if features_labels:
        if list(features_labels.keys()) != ["ULabel"]:
            related_names = dict_related_model_to_related_name(self._host.__class__)
        else:
            related_names = {"ULabel": "ulabels"}
        for class_name, registry_features_labels in features_labels.items():
            related_name = related_names[class_name]  # e.g., "ulabels"
            LinkORM = getattr(self._host, related_name).through
            field_name = f"{get_link_attr(LinkORM, self._host)}_id"  # e.g., ulabel_id
            links = [
                LinkORM(
                    **{
                        "artifact_id": self._host.id,
                        "feature_id": feature.id,
                        field_name: label.id,
                    }
                )
                for (feature, label) in registry_features_labels
            ]
            # a link might already exist
            try:
                save(links, ignore_conflicts=False)
            except Exception:
                save(links, ignore_conflicts=True)
            # now deal with links that were previously saved without a feature_id
            links_saved = LinkORM.filter(
                **{
                    "artifact_id": self._host.id,
                    f"{field_name}__in": [
                        l.id for _, l in registry_features_labels
                    ],
                }
            )
            for link in links_saved.all():
                # TODO: also check for inconsistent features
                if link.feature_id is None:
                    link.feature_id = [
                        f.id
                        for f, l in registry_features_labels
                        if l.id == getattr(link, field_name)
                    ][0]
                    link.save()
    if _feature_values:
        save(_feature_values)
        if is_param:
            LinkORM = self._host._param_values.through
            valuefield_id = "paramvalue_id"
        else:
            LinkORM = self._host._feature_values.through
            valuefield_id = "featurevalue_id"
        links = [
            LinkORM(
                **{
                    f"{self._host.__class__.__get_name_with_schema__().lower()}_id": self._host.id,
                    valuefield_id: feature_value.id,
                }
            )
            for feature_value in _feature_values
        ]
        # a link might already exist, to avoid raising a unique constraint
        # error, ignore_conflicts
        save(links, ignore_conflicts=True)
624
+
625
+
626
def add_values_features(
    self,
    values: dict[str, str | int | float | bool],
    feature_field: FieldAttr = Feature.name,
    str_as_ulabel: bool = True,
) -> None:
    """Curate artifact with features & values.

    Args:
        values: A dictionary of keys (features) & values (labels, numbers, booleans).
        feature_field: The field of a reference registry to map keys of the
            dictionary.
        str_as_ulabel: Whether to interpret string values as ulabels.
    """
    # thin wrapper; validation and linking happen in _add_values
    _add_values(self, values, feature_field, str_as_ulabel=str_as_ulabel)
641
+
642
+
643
def add_values_params(
    self,
    values: dict[str, str | int | float | bool],
) -> None:
    """Curate artifact with features & values.

    Args:
        values: A dictionary of keys (features) & values (labels, numbers, booleans).
    """
    # params map via Param.name and never treat strings as ULabels
    _add_values(self, values, Param.name, str_as_ulabel=False)
653
+
654
+
655
def add_feature_set(self, feature_set: FeatureSet, slot: str) -> None:
    """Curate artifact with a feature set.

    Args:
        feature_set: `FeatureSet` A feature set record.
        slot: `str` The slot that marks where the feature set is stored in
            the artifact.

    Raises:
        ValueError: If the host record has not been saved yet.
    """
    if self._host._state.adding:
        raise ValueError(
            "Please save the artifact or collection before adding a feature set!"
        )
    host_db = self._host._state.db
    # persist the feature set on the same database as the host
    feature_set.save(using=host_db)
    host_id_field = get_host_id_field(self._host)
    kwargs = {
        host_id_field: self._host.id,
        "featureset": feature_set,
        "slot": slot,
    }
    link_record = (
        self._host.feature_sets.through.objects.using(host_db)
        .filter(**kwargs)
        .one_or_none()
    )
    if link_record is None:
        # create the link row and update the manager's slot cache
        self._host.feature_sets.through(**kwargs).save(using=host_db)
        if slot in self._feature_set_by_slot:
            logger.debug(f"replaced existing {slot} feature set")
        self._feature_set_by_slot_[slot] = feature_set  # type: ignore
685
+
686
+
687
def _add_set_from_df(
    self, field: FieldAttr = Feature.name, organism: str | None = None
):
    """Add feature set corresponding to column names of DataFrame."""
    if isinstance(self._host, Artifact):
        assert self._host._accessor == "DataFrame"  # noqa: S101
    else:
        # Collection
        assert self._host.artifact._accessor == "DataFrame"  # noqa: S101

    # parse and register features
    registry = field.field.model
    df = self._host.load()
    features = registry.from_values(df.columns, field=field, organism=organism)
    if len(features) == 0:
        # nothing validated: warn and leave the host unchanged
        logger.error(
            "no validated features found in DataFrame! please register features first!"
        )
        return

    # create and link feature sets under the "columns" slot
    feature_set = FeatureSet(features=features)
    feature_sets = {"columns": feature_set}
    self._host._feature_sets = feature_sets
    self._host.save()
712
+
713
+
714
def _add_set_from_anndata(
    self,
    var_field: FieldAttr,
    obs_field: FieldAttr | None = Feature.name,
    mute: bool = False,
    organism: str | Record | None = None,
):
    """Add features from AnnData.

    Parses `var` (X feature names) and `obs` columns into feature sets and
    links them to the host artifact.
    """
    if isinstance(self._host, Artifact):
        assert self._host._accessor == "AnnData"  # noqa: S101
    else:
        # collections of AnnData are not supported here
        raise NotImplementedError()

    # parse and register features
    adata = self._host.load()
    feature_sets = parse_feature_sets_from_anndata(
        adata,
        var_field=var_field,
        obs_field=obs_field,
        mute=mute,
        organism=organism,
    )

    # link feature sets
    self._host._feature_sets = feature_sets
    self._host.save()
740
+
741
+
742
def _add_set_from_mudata(
    self,
    var_fields: dict[str, FieldAttr],
    obs_fields: dict[str, FieldAttr] = None,
    mute: bool = False,
    organism: str | Record | None = None,
):
    """Add features from MuData.

    Creates one feature set for the global ``obs`` plus per-modality sets
    keyed like ``"['rna'].var"``; sets with identical hashes are deduplicated
    to a single record before linking.
    """
    if obs_fields is None:
        obs_fields = {}
    if isinstance(self._host, Artifact):
        assert self._host._accessor == "MuData"  # noqa: S101
    else:
        raise NotImplementedError()

    # parse and register features
    mdata = self._host.load()
    feature_sets = {}
    obs_features = Feature.from_values(mdata.obs.columns)
    if len(obs_features) > 0:
        feature_sets["obs"] = FeatureSet(features=obs_features)
    for modality, field in var_fields.items():
        modality_fs = parse_feature_sets_from_anndata(
            mdata[modality],
            var_field=field,
            obs_field=obs_fields.get(modality, Feature.name),
            mute=mute,
            organism=organism,
        )
        for k, v in modality_fs.items():
            # slot names encode the modality, e.g. "['rna'].var"
            feature_sets[f"['{modality}'].{k}"] = v

    def unify_feature_sets_by_hash(feature_sets):
        # replace hash-duplicate FeatureSets with one shared record (in place)
        unique_values = {}

        for key, value in feature_sets.items():
            value_hash = value.hash  # Assuming each value has a .hash attribute
            if value_hash in unique_values:
                feature_sets[key] = unique_values[value_hash]
            else:
                unique_values[value_hash] = value

        return feature_sets

    # link feature sets
    self._host._feature_sets = unify_feature_sets_by_hash(feature_sets)
    self._host.save()
789
+
790
+
791
def _add_from(self, data: Artifact | Collection, transfer_logs: dict = None):
    """Transfer features from a artifact or collection.

    For each slot on ``data``, transfers feature-set members missing from the
    default database, then re-creates and links an equivalent feature set on
    the host.
    """
    # This only covers feature sets
    if transfer_logs is None:
        transfer_logs = {"mapped": [], "transferred": []}
    using_key = settings._using_key
    for slot, feature_set in data.features._feature_set_by_slot.items():
        members = feature_set.members
        if len(members) == 0:
            continue
        registry = members[0].__class__
        # note here the features are transferred based on an unique field
        field = REGISTRY_UNIQUE_FIELD.get(registry.__name__.lower(), "uid")
        if hasattr(registry, "_ontology_id_field"):
            field = registry._ontology_id_field
        # this will be e.g. be a list of ontology_ids or uids
        member_uids = list(members.values_list(field, flat=True))
        # create records from ontology_id
        if hasattr(registry, "_ontology_id_field") and len(member_uids) > 0:
            # create from bionty
            members_records = registry.from_values(member_uids, field=field)
            save([r for r in members_records if r._state.adding])
        validated = registry.validate(member_uids, field=field, mute=True)
        new_members_uids = list(compress(member_uids, ~validated))
        new_members = members.filter(**{f"{field}__in": new_members_uids}).all()
        n_new_members = len(new_members)
        if n_new_members > 0:
            # transfer foreign keys needs to be run before transfer to default db
            transfer_fk_to_default_db_bulk(
                new_members, using_key, transfer_logs=transfer_logs
            )
            for feature in new_members:
                # not calling save=True here as in labels, because want to
                # bulk save below
                # transfer_fk is set to False because they are already transferred
                # in the previous step transfer_fk_to_default_db_bulk
                transfer_to_default_db(
                    feature, using_key, transfer_fk=False, transfer_logs=transfer_logs
                )
            logger.info(f"saving {n_new_members} new {registry.__name__} records")
            save(new_members)

        # create a new feature set from feature values using the same uid
        feature_set_self = FeatureSet.from_values(
            member_uids, field=getattr(registry, field)
        )
        if feature_set_self is None:
            if hasattr(registry, "organism_id"):
                logger.warning(
                    f"FeatureSet is not transferred, check if organism is set correctly: {feature_set}"
                )
            continue
        # make sure the uid matches if featureset is composed of same features
        if feature_set_self.hash == feature_set.hash:
            feature_set_self.uid = feature_set.uid
        logger.info(f"saving {slot} featureset: {feature_set_self}")
        self._host.features.add_feature_set(feature_set_self, slot)
848
+
849
+
850
# Attach the implementations above to the manager classes declared in
# lnschema_core; FeatureManager and ParamManager share __init__, __repr__,
# and get_values, while add_values differs per manager.
FeatureManager.__init__ = __init__
ParamManager.__init__ = __init__
FeatureManager.__repr__ = __repr__
ParamManager.__repr__ = __repr__
FeatureManager.__getitem__ = __getitem__
FeatureManager.get_values = get_values
FeatureManager._feature_set_by_slot = _feature_set_by_slot
FeatureManager._accessor_by_registry = _accessor_by_registry
FeatureManager.add_values = add_values_features
FeatureManager.add_feature_set = add_feature_set
FeatureManager._add_set_from_df = _add_set_from_df
FeatureManager._add_set_from_anndata = _add_set_from_anndata
FeatureManager._add_set_from_mudata = _add_set_from_mudata
FeatureManager._add_from = _add_from
FeatureManager.filter = filter
FeatureManager.get = get
ParamManager.add_values = add_values_params
ParamManager.get_values = get_values