lamindb 0.72.1__py3-none-any.whl → 0.73.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,24 +1,36 @@
 from __future__ import annotations
 
+from collections import defaultdict
+from collections.abc import Iterable
 from itertools import compress
-from typing import TYPE_CHECKING, Iterable
+from typing import TYPE_CHECKING, Any
 
 import anndata as ad
+import numpy as np
+import pandas as pd
 from anndata import AnnData
+from django.contrib.postgres.aggregates import ArrayAgg
+from django.db import connections, models
+from django.db.models import Aggregate, CharField, F, Value
+from django.db.models.functions import Concat
 from lamin_utils import colors, logger
 from lamindb_setup.core.upath import create_path
 from lnschema_core.models import (
     Artifact,
     Collection,
-    Data,
     Feature,
+    FeatureManager,
+    FeatureManagerArtifact,
+    FeatureManagerCollection,
     FeatureValue,
+    HasFeatures,
+    LinkORM,
     Registry,
     ULabel,
 )
 
-from lamindb._feature import convert_numpy_dtype_to_lamin_feature_type
-from lamindb._feature_set import FeatureSet
+from lamindb._feature import FEATURE_TYPES, convert_numpy_dtype_to_lamin_feature_type
+from lamindb._feature_set import DICT_KEYS_TYPE, FeatureSet
 from lamindb._registry import (
     REGISTRY_UNIQUE_FIELD,
     get_default_str_field,
@@ -29,6 +41,7 @@ from lamindb._save import save
 from lamindb.core.exceptions import ValidationError
 from lamindb.core.storage import LocalPathClasses
 
+from ._label_manager import get_labels_as_dict
 from ._settings import settings
 
 if TYPE_CHECKING:
@@ -45,7 +58,7 @@ def get_host_id_field(host: Artifact | Collection) -> str:
     return host_id_field
 
 
-def get_accessor_by_orm(host: Artifact | Collection) -> dict:
+def get_accessor_by_registry_(host: Artifact | Collection) -> dict:
     dictionary = {
         field.related_model.__get_name_with_schema__(): field.name
         for field in host._meta.related_objects
@@ -55,7 +68,7 @@ def get_accessor_by_orm(host: Artifact | Collection) -> dict:
     return dictionary
 
 
-def get_feature_set_by_slot(host) -> dict:
+def get_feature_set_by_slot_(host) -> dict:
     # if the host is not yet saved
     if host._state.adding:
         if hasattr(host, "_feature_sets"):
@@ -80,7 +93,7 @@ def get_label_links(
     host_id_field = get_host_id_field(host)
     kwargs = {host_id_field: host.id, "feature_id": feature.id}
     link_records = (
-        getattr(host, host.features.accessor_by_orm[registry])
+        getattr(host, host.features._accessor_by_registry[registry])
         .through.objects.using(host._state.db)
         .filter(**kwargs)
     )
@@ -94,53 +107,95 @@ def get_feature_set_links(host: Artifact | Collection) -> QuerySet:
     return feature_set_links
 
 
-def print_features(self: Data) -> str:
+def get_link_attr(link: LinkORM, data: HasFeatures) -> str:
+    link_model_name = link.__class__.__name__
+    link_attr = link_model_name.replace(data.__class__.__name__, "")
+    if link_attr == "ExperimentalFactor":
+        link_attr = "experimental_factor"
+    else:
+        link_attr = link_attr.lower()
+    return link_attr
+
+
+# Custom aggregation for SQLite
+class GroupConcat(Aggregate):
+    function = "GROUP_CONCAT"
+    template = '%(function)s(%(expressions)s, ", ")'
+
+
+def custom_aggregate(field, using: str):
+    if connections[using].vendor == "postgresql":
+        return ArrayAgg(field)
+    else:
+        return GroupConcat(field)
+
+
+def print_features(
+    self: HasFeatures, print_types: bool = False, to_dict: bool = False
+) -> str | dict[str, Any]:
     from lamindb._from_values import _print_values
 
-    from ._data import format_repr
-
-    messages = []
-    for slot, feature_set in get_feature_set_by_slot(self).items():
-        if feature_set.registry != "Feature":
-            features = feature_set.members
-            # features.first() is a lot slower than features[0] here
-            name_field = get_default_str_field(features[0])
-            feature_names = list(features.values_list(name_field, flat=True)[:30])
-            messages.append(
-                f" {colors.bold(slot)}: {format_repr(feature_set, exclude='hash')}\n"
-            )
-            print_values = _print_values(feature_names, n=20)
-            messages.append(f" {print_values}\n")
-        else:
-            features_lookup = {
-                f.name: f for f in Feature.objects.using(self._state.db).filter().all()
-            }
-            messages.append(
-                f" {colors.bold(slot)}: {format_repr(feature_set, exclude='hash')}\n"
+    msg = ""
+    dictionary = {}
+    # categorical feature values
+    labels_msg = ""
+    labels_by_feature = defaultdict(list)
+    for _, (_, links) in get_labels_as_dict(self, links=True).items():
+        for link in links:
+            if link.feature_id is not None:
+                link_attr = get_link_attr(link, self)
+                labels_by_feature[link.feature_id].append(getattr(link, link_attr).name)
+    for feature_id, labels_list in labels_by_feature.items():
+        feature = Feature.objects.using(self._state.db).get(id=feature_id)
+        print_values = _print_values(labels_list, n=10)
+        type_str = f": {feature.dtype}" if print_types else ""
+        if to_dict:
+            dictionary[feature.name] = (
+                labels_list if len(labels_list) > 1 else labels_list[0]
             )
-            for name, dtype in feature_set.features.values_list("name", "dtype"):
-                if dtype.startswith("cat["):
-                    labels = self.labels.get(features_lookup.get(name), mute=True)
-                    indent = ""
-                    if isinstance(labels, dict):
-                        messages.append(f" 🔗 {name} ({dtype})\n")
-                        indent = " "
-                    else:
-                        labels = {dtype: labels}
-                    for registry, registry_labels in labels.items():
-                        field = get_default_str_field(registry_labels)
-                        values_list = registry_labels.values_list(field, flat=True)
-                        count_str = f"{feature_set.n}, {colors.italic(f'{registry}')}"
-                        print_values = _print_values(values_list[:20], n=10)
-                        msg_objects = (
-                            f"{indent} 🔗 {name} ({count_str}):" f" {print_values}\n"
-                        )
-                        messages.append(msg_objects)
-                else:
-                    messages.append(f" {name} ({dtype})\n")
-    if messages:
-        messages.insert(0, f"{colors.green('Features')}:\n")
-    return "".join(messages)
+        labels_msg += f" '{feature.name}'{type_str} = {print_values}\n"
+    if labels_msg:
+        msg += labels_msg
+
+    # non-categorical feature values
+    non_labels_msg = ""
+    if self.id is not None and self.__class__ == Artifact:
+        feature_values = self.feature_values.values(
+            "feature__name", "feature__dtype"
+        ).annotate(values=custom_aggregate("value", self._state.db))
+        if len(feature_values) > 0:
+            for fv in feature_values:
+                feature_name = fv["feature__name"]
+                feature_dtype = fv["feature__dtype"]
+                values = fv["values"]
+                # TODO: understand why the below is necessary
+                if not isinstance(values, list):
+                    values = [values]
+                if to_dict:
+                    dictionary[feature_name] = values if len(values) > 1 else values[0]
+                type_str = f": {feature_dtype}" if print_types else ""
+                non_labels_msg += f" '{feature_name}'{type_str} = {_print_values(values, n=10, quotes=False)}\n"
+        msg += non_labels_msg
+
+    if msg != "":
+        msg = f" {colors.italic('Features')}\n" + msg
+
+    # feature sets
+    feature_set_msg = ""
+    for slot, feature_set in get_feature_set_by_slot_(self).items():
+        features = feature_set.members
+        # features.first() is a lot slower than features[0] here
+        name_field = get_default_str_field(features[0])
+        feature_names = list(features.values_list(name_field, flat=True)[:20])
+        type_str = f": {feature_set.registry}" if print_types else ""
+        feature_set_msg += f" '{slot}'{type_str} = {_print_values(feature_names)}\n"
+    if feature_set_msg:
+        msg += f" {colors.italic('Feature sets')}\n"
+        msg += feature_set_msg
+    if to_dict:
+        return dictionary
+    else:
+        return msg
 
 
 def parse_feature_sets_from_anndata(
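
Editor's note on the hunk above: custom_aggregate dispatches on the database vendor, since Django's ArrayAgg is PostgreSQL-only and the GROUP_CONCAT-based GroupConcat aggregate serves as the SQLite fallback. The two backends return differently shaped values for the same annotation query, a Python list per row on PostgreSQL versus a single delimited string on SQLite, which is presumably why print_features wraps non-list results before formatting them. A minimal sketch of that normalization (illustrative only, not code from the package):

    def as_list(values):
        # mirrors `if not isinstance(values, list): values = [values]` in print_features
        return values if isinstance(values, list) else [values]

    print(as_list(["a", "b"]))  # PostgreSQL ArrayAgg row -> ['a', 'b']
    print(as_list("a, b"))      # SQLite GROUP_CONCAT row -> ['a, b'] (one joined string)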
@@ -203,305 +258,443 @@ def parse_feature_sets_from_anndata(
     return feature_sets
 
 
-class FeatureManager:
-    """Feature manager (:attr:`~lamindb.core.Data.features`).
+def infer_feature_type_convert_json(value: Any, mute: bool = False) -> tuple[str, Any]:
+    if isinstance(value, bool):
+        return FEATURE_TYPES["bool"], value
+    elif isinstance(value, int):
+        return FEATURE_TYPES["int"], value
+    elif isinstance(value, float):
+        return FEATURE_TYPES["float"], value
+    elif isinstance(value, str):
+        return FEATURE_TYPES["str"] + "[ULabel]", value
+    elif isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
+        if isinstance(value, (pd.Series, np.ndarray)):
+            return convert_numpy_dtype_to_lamin_feature_type(value.dtype), list(value)
+        if len(value) > 0:  # type: ignore
+            first_element_type = type(next(iter(value)))
+            if all(isinstance(elem, first_element_type) for elem in value):
+                if first_element_type == bool:
+                    return FEATURE_TYPES["bool"], value
+                elif first_element_type == int:
+                    return FEATURE_TYPES["int"], value
+                elif first_element_type == float:
+                    return FEATURE_TYPES["float"], value
+                elif first_element_type == str:
+                    return FEATURE_TYPES["str"] + "[ULabel]", value
+    if not mute:
+        logger.warning(f"cannot infer feature type of: {value}, returning '?")
+    return ("?", value)
+
+
+def __init__(self, host: Artifact | Collection):
+    self._host = host
+    self._feature_set_by_slot_ = None
+    self._accessor_by_registry_ = None
+
+
+def __repr__(self) -> str:
+    return print_features(self._host)  # type: ignore
+
+
+def get_values(self) -> dict[str, Any]:
+    """Get feature values as a dictionary."""
+    return print_features(self._host, to_dict=True)  # type: ignore
+
+
+def __getitem__(self, slot) -> QuerySet:
+    if slot not in self._feature_set_by_slot:
+        raise ValueError(
+            f"No linked feature set for slot: {slot}\nDid you get validation"
+            " warnings? Only features that match registered features get validated"
+            " and linked."
+        )
+    feature_set = self._feature_set_by_slot[slot]
+    orm_name = feature_set.registry
+    return getattr(feature_set, self._accessor_by_registry[orm_name]).all()
+
+
+@classmethod  # type: ignore
+def filter(cls, **expression) -> QuerySet:
+    """Filter features."""
+    keys_normalized = [key.split("__")[0] for key in expression]
+    validated = Feature.validate(keys_normalized, field="name", mute=True)
+    if sum(validated) != len(keys_normalized):
+        raise ValidationError(
+            f"Some keys in the filter expression are not registered as features: {np.array(keys_normalized)[~validated]}"
+        )
+    new_expression = {}
+    features = Feature.filter(name__in=keys_normalized).all().distinct()
+    for key, value in expression.items():
+        normalized_key = key.split("__")[0]
+        feature = features.get(name=normalized_key)
+        if not feature.dtype.startswith("cat"):
+            feature_value = FeatureValue.filter(feature=feature, value=value).one()
+            new_expression["feature_values"] = feature_value
+        else:
+            if isinstance(value, str):
+                label = ULabel.filter(name=value).one()
+                new_expression["ulabels"] = label
+            else:
+                raise NotImplementedError
+    if cls == FeatureManagerArtifact:
+        return Artifact.filter(**new_expression)
+    else:
+        return Collection.filter(**new_expression)
 
-    See :class:`~lamindb.core.Data` for more information.
-    """
 
-    def __init__(self, host: Artifact | Collection):
-        self._host = host
-        self._feature_set_by_slot = None
-        self._accessor_by_orm = None
+@property  # type: ignore
+def _feature_set_by_slot(self):
+    """Feature sets by slot."""
+    if self._feature_set_by_slot_ is None:
+        self._feature_set_by_slot_ = get_feature_set_by_slot_(self._host)
+    return self._feature_set_by_slot_
 
-    def __repr__(self) -> str:
-        if len(self.feature_set_by_slot) > 0:
-            return print_features(self._host)
-        else:
-            return "no linked features"
-
-    def __getitem__(self, slot) -> QuerySet:
-        if slot not in self.feature_set_by_slot:
-            raise ValueError(
-                f"No linked feature set for slot: {slot}\nDid you get validation"
-                " warnings? Only features that match registered features get validated"
-                " and linked."
-            )
-        feature_set = self.feature_set_by_slot[slot]
-        orm_name = feature_set.registry
-        if hasattr(feature_set, "_features"):
-            # feature set is not yet saved
-            # need to think about turning this into a queryset
-            return feature_set._features
+
+@property  # type: ignore
+def _accessor_by_registry(self):
+    """Accessor by ORM."""
+    if self._accessor_by_registry_ is None:
+        self._accessor_by_registry_ = get_accessor_by_registry_(self._host)
+    return self._accessor_by_registry_
+
+
+def add_values(
+    self,
+    values: dict[str, str | int | float | bool],
+    feature_field: FieldAttr = Feature.name,
+) -> None:
+    """Annotate artifact with features & values.
+
+    Args:
+        values: A dictionary of keys (features) & values (labels, numbers, booleans).
+        feature_field: The field of a reference registry to map keys of the
+            dictionary.
+    """
+    # rename to distinguish from the values inside the dict
+    features_values = values
+    keys = features_values.keys()
+    if isinstance(keys, DICT_KEYS_TYPE):
+        keys = list(keys)  # type: ignore
+    # deal with other cases later
+    assert all(isinstance(key, str) for key in keys)
+    registry = feature_field.field.model
+    validated = registry.validate(keys, field=feature_field, mute=True)
+    keys_array = np.array(keys)
+    validated_keys = keys_array[validated]
+    if validated.sum() != len(keys):
+        not_validated_keys = keys_array[~validated]
+        hint = "\n".join(
+            [
+                f" ln.Feature(name='{key}', dtype='{infer_feature_type_convert_json(features_values[key])[0]}').save()"
+                for key in not_validated_keys
+            ]
+        )
+        msg = (
+            f"These keys could not be validated: {not_validated_keys.tolist()}\n"
+            f"If there are no typos, create features for them:\n\n{hint}"
+        )
+        raise ValidationError(msg)
+    registry.from_values(
+        validated_keys,
+        field=feature_field,
+    )
+    # figure out which of the values go where
+    features_labels = []
+    feature_values = []
+    not_validated_values = []
+    for key, value in features_values.items():
+        feature = Feature.filter(name=key).one()
+        inferred_type, converted_value = infer_feature_type_convert_json(
+            value, mute=True
+        )
+        if feature.dtype == "number":
+            if inferred_type not in {"int", "float"}:
+                raise TypeError(
+                    f"Value for feature '{key}' with type {feature.dtype} must be a number"
+                )
+        elif feature.dtype == "cat":
+            if not (inferred_type.startswith("cat") or isinstance(value, Registry)):
+                raise TypeError(
+                    f"Value for feature '{key}' with type '{feature.dtype}' must be a string or record."
+                )
+        elif feature.dtype == "bool":
+            assert isinstance(value, bool)
+        if not feature.dtype.startswith("cat"):
+            # can remove the query once we have the unique constraint
+            feature_value = FeatureValue.filter(
+                feature=feature, value=converted_value
+            ).one_or_none()
+            if feature_value is None:
+                feature_value = FeatureValue(feature=feature, value=converted_value)
+            feature_values.append(feature_value)
         else:
-            return getattr(feature_set, self.accessor_by_orm[orm_name]).all()
-
-    @property
-    def feature_set_by_slot(self):
-        """Feature sets by slot."""
-        if self._feature_set_by_slot is None:
-            self._feature_set_by_slot = get_feature_set_by_slot(self._host)
-        return self._feature_set_by_slot
-
-    @property
-    def accessor_by_orm(self):
-        """Accessor by ORM."""
-        if self._accessor_by_orm is None:
-            self._accessor_by_orm = get_accessor_by_orm(self._host)
-        return self._accessor_by_orm
-
-    def add(
-        self,
-        features_values: dict[str, str | int | float | bool],
-        slot: str | None = None,
-        feature_field: FieldAttr = Feature.name,
-    ):
-        """Add features stratified by slot.
-
-        Args:
-            features_values: A dictionary of features & values. You can also
-                pass `{feature_identifier: None}` to skip annotation by values.
-            slot: The access slot of the feature sets in the artifact. For
-                instance, `.columns` for `DataFrame` or `.var` or `.obs` for
-                `AnnData`.
-            feature_field: The field of a reference registry to map values.
-        """
-        if slot is None:
-            slot = "external"
-        keys = features_values.keys()
-        features_values.values()
-        # what if the feature is already part of a linked feature set?
-        # what if artifact annotation by features through link tables and through feature sets
-        # differs?
-        feature_set = FeatureSet.from_values(keys, field=feature_field)
-        self._host.features.add_feature_set(feature_set, slot)
-        # now figure out which of the values go where
-        features_labels = []
-        feature_values = []
-        for key, value in features_values.items():
-            # TODO: use proper field in .get() below
-            feature = feature_set.features.get(name=key)
-            if feature.dtype == "number":
-                if not (isinstance(value, int) or isinstance(value, float)):
-                    raise TypeError(
-                        f"Value for feature '{key}' with type {feature.dtype} must be a number"
-                    )
-            elif feature.dtype == "cat":
-                if not (isinstance(value, str) or isinstance(value, Registry)):
-                    raise TypeError(
-                        f"Value for feature '{key}' with type '{feature.dtype}' must be a string or record."
-                    )
-            elif feature.dtype == "bool":
-                assert isinstance(value, bool)
-            if feature.dtype == "cat":
-                if isinstance(value, Registry):
-                    assert not value._state.adding
-                    label_record = value
-                    assert isinstance(label_record, ULabel)
-                else:
-                    label_record = ULabel.filter(name=value).one_or_none()
-                    if label_record is None:
-                        raise ValidationError(f"Label '{value}' not found in ln.ULabel")
+            if isinstance(value, Registry):
+                assert not value._state.adding
+                label_record = value
+                assert isinstance(label_record, ULabel)
                 features_labels.append((feature, label_record))
             else:
-                feature_values.append(FeatureValue(feature=feature, value=value))
-        # bulk add all links to ArtifactULabel
-        if features_labels:
-            LinkORM = self._host.ulabels.through
-            links = [
-                LinkORM(
-                    artifact_id=self._host.id, feature_id=feature.id, ulabel_id=label.id
-                )
-                for (feature, label) in features_labels
-            ]
-            LinkORM.objects.bulk_create(links, ignore_conflicts=True)
-        if feature_values:
-            save(feature_values)
-            LinkORM = self._host.feature_values.through
-            links = [
-                LinkORM(artifact_id=self._host.id, featurevalue_id=feature_value.id)
-                for feature_value in feature_values
-            ]
-            LinkORM.objects.bulk_create(links)
-
-    def add_from_df(self, field: FieldAttr = Feature.name, organism: str | None = None):
-        """Add features from DataFrame."""
-        if isinstance(self._host, Artifact):
-            assert self._host.accessor == "DataFrame"
-        else:
-            # Collection
-            assert self._host.artifact.accessor == "DataFrame"
-
-        # parse and register features
-        registry = field.field.model
-        df = self._host.load()
-        features = registry.from_values(df.columns, field=field, organism=organism)
-        if len(features) == 0:
-            logger.error(
-                "no validated features found in DataFrame! please register features first!"
+                if isinstance(value, str):
+                    values = [value]  # type: ignore
+                else:
+                    values = value  # type: ignore
+                if "ULabel" not in feature.dtype:
+                    feature.dtype += "[ULabel]"
+                    feature.save()
+                validated = ULabel.validate(values, field="name", mute=True)
+                values_array = np.array(values)
+                validated_values = values_array[validated]
+                if validated.sum() != len(values):
+                    not_validated_values += values_array[~validated].tolist()
+                label_records = ULabel.from_values(validated_values, field="name")
+                features_labels += [
+                    (feature, label_record) for label_record in label_records
+                ]
+    if not_validated_values:
+        hint = (
+            f" ulabels = ln.ULabel.from_values({not_validated_values}, create=True)\n"
+            f" ln.save(ulabels)"
+        )
+        msg = (
+            f"These values could not be validated: {not_validated_values}\n"
+            f"If there are no typos, create ulabels for them:\n\n{hint}"
+        )
+        raise ValidationError(msg)
+    # bulk add all links to ArtifactULabel
+    if features_labels:
+        LinkORM = self._host.ulabels.through
+        links = [
+            LinkORM(
+                artifact_id=self._host.id, feature_id=feature.id, ulabel_id=label.id
             )
-            return
-
-        # create and link feature sets
-        feature_set = FeatureSet(features=features)
-        feature_sets = {"columns": feature_set}
-        self._host._feature_sets = feature_sets
-        self._host.save()
-
-    def add_from_anndata(
-        self,
-        var_field: FieldAttr,
-        obs_field: FieldAttr | None = Feature.name,
-        mute: bool = False,
-        organism: str | Registry | None = None,
-    ):
-        """Add features from AnnData."""
-        if isinstance(self._host, Artifact):
-            assert self._host.accessor == "AnnData"
-        else:
-            raise NotImplementedError()
-
-        # parse and register features
-        adata = self._host.load()
-        feature_sets = parse_feature_sets_from_anndata(
-            adata,
-            var_field=var_field,
-            obs_field=obs_field,
+            for (feature, label) in features_labels
+        ]
+        # a link might already exist
+        try:
+            save(links, ignore_conflicts=False)
+        except Exception:
+            save(links, ignore_conflicts=True)
+        # now deal with links that were previously saved without a feature_id
+        saved_links = LinkORM.filter(
+            artifact_id=self._host.id,
+            ulabel_id__in=[l.id for _, l in features_labels],
+        )
+        for link in saved_links.all():
+            # TODO: also check for inconsistent features
+            if link.feature_id is None:
+                link.feature_id = [
+                    f.id for f, l in features_labels if l.id == link.ulabel_id
+                ][0]
+                link.save()
+    if feature_values:
+        save(feature_values)
+        LinkORM = self._host.feature_values.through
+        links = [
+            LinkORM(artifact_id=self._host.id, featurevalue_id=feature_value.id)
+            for feature_value in feature_values
+        ]
+        # a link might already exist, to avoid raising a unique constraint
+        # error, ignore_conflicts
+        save(links, ignore_conflicts=True)
+
+
+def add_feature_set(self, feature_set: FeatureSet, slot: str) -> None:
+    """Annotate artifact with a feature set.
+
+    Args:
+        feature_set: `FeatureSet` A feature set record.
+        slot: `str` The slot that marks where the feature set is stored in
+            the artifact.
+    """
+    if self._host._state.adding:
+        raise ValueError(
+            "Please save the artifact or collection before adding a feature set!"
+        )
+    host_db = self._host._state.db
+    feature_set.save(using=host_db)
+    host_id_field = get_host_id_field(self._host)
+    kwargs = {
+        host_id_field: self._host.id,
+        "featureset": feature_set,
+        "slot": slot,
+    }
+    link_record = (
+        self._host.feature_sets.through.objects.using(host_db)
+        .filter(**kwargs)
+        .one_or_none()
+    )
+    if link_record is None:
+        self._host.feature_sets.through(**kwargs).save(using=host_db)
+    if slot in self._feature_set_by_slot:
+        logger.debug(f"replaced existing {slot} feature set")
+    self._feature_set_by_slot_[slot] = feature_set  # type: ignore
+
+
+def _add_set_from_df(
+    self, field: FieldAttr = Feature.name, organism: str | None = None
+):
+    """Add feature set corresponding to column names of DataFrame."""
+    if isinstance(self._host, Artifact):
+        assert self._host.accessor == "DataFrame"
+    else:
+        # Collection
+        assert self._host.artifact.accessor == "DataFrame"
+
+    # parse and register features
+    registry = field.field.model
+    df = self._host.load()
+    features = registry.from_values(df.columns, field=field, organism=organism)
+    if len(features) == 0:
+        logger.error(
+            "no validated features found in DataFrame! please register features first!"
+        )
+        return
+
+    # create and link feature sets
+    feature_set = FeatureSet(features=features)
+    feature_sets = {"columns": feature_set}
+    self._host._feature_sets = feature_sets
+    self._host.save()
+
+
+def _add_set_from_anndata(
+    self,
+    var_field: FieldAttr,
+    obs_field: FieldAttr | None = Feature.name,
+    mute: bool = False,
+    organism: str | Registry | None = None,
+):
+    """Add features from AnnData."""
+    if isinstance(self._host, Artifact):
+        assert self._host.accessor == "AnnData"
+    else:
+        raise NotImplementedError()
+
+    # parse and register features
+    adata = self._host.load()
+    feature_sets = parse_feature_sets_from_anndata(
+        adata,
+        var_field=var_field,
+        obs_field=obs_field,
+        mute=mute,
+        organism=organism,
+    )
+
+    # link feature sets
+    self._host._feature_sets = feature_sets
+    self._host.save()
+
+
+def _add_set_from_mudata(
+    self,
+    var_fields: dict[str, FieldAttr],
+    obs_fields: dict[str, FieldAttr] = None,
+    mute: bool = False,
+    organism: str | Registry | None = None,
+):
+    """Add features from MuData."""
+    if obs_fields is None:
+        obs_fields = {}
+    if isinstance(self._host, Artifact):
+        assert self._host.accessor == "MuData"
+    else:
+        raise NotImplementedError()
+
+    # parse and register features
+    mdata = self._host.load()
+    feature_sets = {}
+    obs_features = features = Feature.from_values(mdata.obs.columns)
+    if len(obs_features) > 0:
+        feature_sets["obs"] = FeatureSet(features=features)
+    for modality, field in var_fields.items():
+        modality_fs = parse_feature_sets_from_anndata(
+            mdata[modality],
+            var_field=field,
+            obs_field=obs_fields.get(modality, Feature.name),
             mute=mute,
             organism=organism,
         )
-
-        # link feature sets
-        self._host._feature_sets = feature_sets
-        self._host.save()
-
-    def add_from_mudata(
-        self,
-        var_fields: dict[str, FieldAttr],
-        obs_fields: dict[str, FieldAttr] = None,
-        mute: bool = False,
-        organism: str | Registry | None = None,
-    ):
-        """Add features from MuData."""
-        if obs_fields is None:
-            obs_fields = {}
-        if isinstance(self._host, Artifact):
-            assert self._host.accessor == "MuData"
-        else:
-            raise NotImplementedError()
-
-        # parse and register features
-        mdata = self._host.load()
-        feature_sets = {}
-        obs_features = features = Feature.from_values(mdata.obs.columns)
-        if len(obs_features) > 0:
-            feature_sets["obs"] = FeatureSet(features=features)
-        for modality, field in var_fields.items():
-            modality_fs = parse_feature_sets_from_anndata(
-                mdata[modality],
-                var_field=field,
-                obs_field=obs_fields.get(modality, Feature.name),
-                mute=mute,
-                organism=organism,
-            )
-            for k, v in modality_fs.items():
-                feature_sets[f"['{modality}'].{k}"] = v
-
-        # link feature sets
-        self._host._feature_sets = feature_sets
-        self._host.save()
-
-    def add_feature_set(self, feature_set: FeatureSet, slot: str):
-        """Add new feature set to a slot.
-
-        Args:
-            feature_set: `FeatureSet` A feature set object.
-            slot: `str` The access slot.
-        """
-        if self._host._state.adding:
-            raise ValueError(
-                "Please save the artifact or collection before adding a feature set!"
-            )
-        host_db = self._host._state.db
-        feature_set.save(using=host_db)
-        host_id_field = get_host_id_field(self._host)
-        kwargs = {
-            host_id_field: self._host.id,
-            "featureset": feature_set,
-            "slot": slot,
-        }
-        link_record = (
-            self._host.feature_sets.through.objects.using(host_db)
-            .filter(**kwargs)
-            .one_or_none()
+        for k, v in modality_fs.items():
+            feature_sets[f"['{modality}'].{k}"] = v
+
+    # link feature sets
+    self._host._feature_sets = feature_sets
+    self._host.save()
+
+
+def _add_from(self, data: HasFeatures, parents: bool = True):
+    """Transfer features from a artifact or collection."""
+    # This only covers feature sets, though.
+    using_key = settings._using_key
+    for slot, feature_set in data.features._feature_set_by_slot.items():
+        members = feature_set.members
+        if len(members) == 0:
+            continue
+        registry = members[0].__class__
+        # note here the features are transferred based on an unique field
+        field = REGISTRY_UNIQUE_FIELD.get(registry.__name__.lower(), "uid")
+        # TODO: get a default ID field for the registry
+        if hasattr(registry, "ontology_id") and parents:
+            field = "ontology_id"
+        elif hasattr(registry, "ensembl_gene_id"):
+            field = "ensembl_gene_id"
+        elif hasattr(registry, "uniprotkb_id"):
+            field = "uniprotkb_id"
+
+        if registry.__get_name_with_schema__() == "bionty.Organism":
+            parents = False
+        # this will be e.g. be a list of ontology_ids or uids
+        member_uids = list(members.values_list(field, flat=True))
+        # create records from ontology_id in order to populate parents
+        if field == "ontology_id" and len(member_uids) > 0 and parents:
+            # create from bionty
+            records = registry.from_values(member_uids, field=field)
+            if len(records) > 0:
+                save(records, parents=parents)
+        validated = registry.validate(member_uids, field=field, mute=True)
+        new_members_uids = list(compress(member_uids, ~validated))
+        new_members = members.filter(**{f"{field}__in": new_members_uids}).all()
+        n_new_members = len(new_members)
+        if n_new_members > 0:
+            mute = True if n_new_members > 10 else False
+            # transfer foreign keys needs to be run before transfer to default db
+            transfer_fk_to_default_db_bulk(new_members, using_key)
+            for feature in new_members:
+                # not calling save=True here as in labels, because want to
+                # bulk save below
+                # transfer_fk is set to False because they are already transferred
+                # in the previous step transfer_fk_to_default_db_bulk
+                transfer_to_default_db(feature, using_key, mute=mute, transfer_fk=False)
+            logger.info(f"saving {n_new_members} new {registry.__name__} records")
+            save(new_members, parents=parents)
+
+        # create a new feature set from feature values using the same uid
+        feature_set_self = FeatureSet.from_values(
+            member_uids, field=getattr(registry, field)
         )
-        if link_record is None:
-            self._host.feature_sets.through(**kwargs).save(using=host_db)
-        if slot in self.feature_set_by_slot:
-            logger.debug(f"replaced existing {slot} feature set")
-        # this _feature_set_by_slot here is private
-        self._feature_set_by_slot[slot] = feature_set  # type: ignore
-
-    def _add_from(self, data: Data, parents: bool = True):
-        """Transfer features from a artifact or collection."""
-        using_key = settings._using_key
-        for slot, feature_set in data.features.feature_set_by_slot.items():
-            print(slot)
-            members = feature_set.members
-            if len(members) == 0:
-                continue
-            registry = members[0].__class__
-            # note here the features are transferred based on an unique field
-            field = REGISTRY_UNIQUE_FIELD.get(registry.__name__.lower(), "uid")
-            # TODO: get a default ID field for the registry
-            if hasattr(registry, "ontology_id") and parents:
-                field = "ontology_id"
-            elif hasattr(registry, "ensembl_gene_id"):
-                field = "ensembl_gene_id"
-            elif hasattr(registry, "uniprotkb_id"):
-                field = "uniprotkb_id"
-
-            if registry.__get_name_with_schema__() == "bionty.Organism":
-                parents = False
-            # this will be e.g. be a list of ontology_ids or uids
-            member_uids = list(members.values_list(field, flat=True))
-            # create records from ontology_id in order to populate parents
-            if field == "ontology_id" and len(member_uids) > 0 and parents:
-                # create from bionty
-                records = registry.from_values(member_uids, field=field)
-                if len(records) > 0:
-                    save(records, parents=parents)
-            validated = registry.validate(member_uids, field=field, mute=True)
-            new_members_uids = list(compress(member_uids, ~validated))
-            new_members = members.filter(**{f"{field}__in": new_members_uids}).all()
-            n_new_members = len(new_members)
-            if n_new_members > 0:
-                mute = True if n_new_members > 10 else False
-                # transfer foreign keys needs to be run before transfer to default db
-                transfer_fk_to_default_db_bulk(new_members, using_key)
-                for feature in new_members:
-                    # not calling save=True here as in labels, because want to
-                    # bulk save below
-                    # transfer_fk is set to False because they are already transferred
-                    # in the previous step transfer_fk_to_default_db_bulk
-                    transfer_to_default_db(
-                        feature, using_key, mute=mute, transfer_fk=False
-                    )
-                logger.info(f"saving {n_new_members} new {registry.__name__} records")
-                save(new_members, parents=parents)
-
-            # create a new feature set from feature values using the same uid
-            feature_set_self = FeatureSet.from_values(
-                member_uids, field=getattr(registry, field)
-            )
-            if feature_set_self is None:
-                if hasattr(registry, "organism"):
-                    logger.warning(
-                        f"FeatureSet is not transferred, check if organism is set correctly: {feature_set}"
-                    )
-                continue
-            # make sure the uid matches if featureset is composed of same features
-            if feature_set_self.hash == feature_set.hash:
-                feature_set_self.uid = feature_set.uid
-            logger.info(f"saving {slot} featureset: {feature_set_self}")
-            self._host.features.add_feature_set(feature_set_self, slot)
+        if feature_set_self is None:
+            if hasattr(registry, "organism"):
+                logger.warning(
+                    f"FeatureSet is not transferred, check if organism is set correctly: {feature_set}"
+                )
+            continue
+        # make sure the uid matches if featureset is composed of same features
+        if feature_set_self.hash == feature_set.hash:
+            feature_set_self.uid = feature_set.uid
+        logger.info(f"saving {slot} featureset: {feature_set_self}")
+        self._host.features.add_feature_set(feature_set_self, slot)
+
+
+FeatureManager.__init__ = __init__
+FeatureManager.__repr__ = __repr__
+FeatureManager.__getitem__ = __getitem__
+FeatureManager.get_values = get_values
+FeatureManager._feature_set_by_slot = _feature_set_by_slot
+FeatureManager._accessor_by_registry = _accessor_by_registry
+FeatureManager.add_values = add_values
+FeatureManager.add_feature_set = add_feature_set
+FeatureManager._add_set_from_df = _add_set_from_df
+FeatureManager._add_set_from_anndata = _add_set_from_anndata
+FeatureManager._add_set_from_mudata = _add_set_from_mudata
+FeatureManager._add_from = _add_from
+FeatureManager.filter = filter
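
Editor's note: the net effect of the final hunk is that the FeatureManager class defined in this module is removed, and its behavior is re-implemented as module-level functions attached to the FeatureManager, FeatureManagerArtifact, and FeatureManagerCollection classes now imported from lnschema_core.models. Annotation moves from the old slot-based add() to add_values(), annotations are read back with get_values(), and the new classmethod filter() queries records by feature values. A hedged usage sketch; the file path, feature names, label, and the ln.Artifact.features.filter access pattern are illustrative assumptions, not taken from this diff:

    import lamindb as ln

    # Features and labels must exist up front; add_values raises a
    # ValidationError with a creation hint for anything unknown.
    ln.Feature(name="species", dtype="cat[ULabel]").save()
    ln.Feature(name="temperature", dtype="float").save()
    ln.ULabel(name="mouse").save()

    artifact = ln.Artifact("data/example.parquet", description="example")  # hypothetical file
    artifact.save()

    # Annotate with a dict of feature names to values (labels, numbers, booleans).
    artifact.features.add_values({"species": "mouse", "temperature": 21.6})

    # Read the annotations back as a dictionary.
    artifact.features.get_values()  # {'species': 'mouse', 'temperature': 21.6}

    # Query by a feature value (assumed exposure of the new classmethod).
    ln.Artifact.features.filter(temperature=21.6)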