lamindb 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (57)
  1. lamindb/__init__.py +52 -36
  2. lamindb/_finish.py +17 -10
  3. lamindb/_tracked.py +1 -1
  4. lamindb/base/__init__.py +3 -1
  5. lamindb/base/fields.py +40 -22
  6. lamindb/base/ids.py +1 -94
  7. lamindb/base/types.py +2 -0
  8. lamindb/base/uids.py +117 -0
  9. lamindb/core/_context.py +203 -102
  10. lamindb/core/_settings.py +38 -25
  11. lamindb/core/datasets/__init__.py +11 -4
  12. lamindb/core/datasets/_core.py +5 -5
  13. lamindb/core/datasets/_small.py +0 -93
  14. lamindb/core/datasets/mini_immuno.py +172 -0
  15. lamindb/core/loaders.py +1 -1
  16. lamindb/core/storage/_backed_access.py +100 -6
  17. lamindb/core/storage/_polars_lazy_df.py +51 -0
  18. lamindb/core/storage/_pyarrow_dataset.py +15 -30
  19. lamindb/core/storage/_tiledbsoma.py +29 -13
  20. lamindb/core/storage/objects.py +6 -0
  21. lamindb/core/subsettings/__init__.py +2 -0
  22. lamindb/core/subsettings/_annotation_settings.py +11 -0
  23. lamindb/curators/__init__.py +7 -3349
  24. lamindb/curators/_legacy.py +2056 -0
  25. lamindb/curators/core.py +1534 -0
  26. lamindb/errors.py +11 -0
  27. lamindb/examples/__init__.py +27 -0
  28. lamindb/examples/schemas/__init__.py +12 -0
  29. lamindb/examples/schemas/_anndata.py +25 -0
  30. lamindb/examples/schemas/_simple.py +19 -0
  31. lamindb/integrations/_vitessce.py +8 -5
  32. lamindb/migrations/0091_alter_featurevalue_options_alter_space_options_and_more.py +24 -0
  33. lamindb/migrations/0092_alter_artifactfeaturevalue_artifact_and_more.py +75 -0
  34. lamindb/migrations/0093_alter_schemacomponent_unique_together.py +16 -0
  35. lamindb/models/__init__.py +4 -1
  36. lamindb/models/_describe.py +21 -4
  37. lamindb/models/_feature_manager.py +382 -287
  38. lamindb/models/_label_manager.py +8 -2
  39. lamindb/models/artifact.py +177 -106
  40. lamindb/models/artifact_set.py +122 -0
  41. lamindb/models/collection.py +73 -52
  42. lamindb/models/core.py +1 -1
  43. lamindb/models/feature.py +51 -17
  44. lamindb/models/has_parents.py +69 -14
  45. lamindb/models/project.py +1 -1
  46. lamindb/models/query_manager.py +221 -22
  47. lamindb/models/query_set.py +247 -172
  48. lamindb/models/record.py +65 -247
  49. lamindb/models/run.py +4 -4
  50. lamindb/models/save.py +8 -2
  51. lamindb/models/schema.py +456 -184
  52. lamindb/models/transform.py +2 -2
  53. lamindb/models/ulabel.py +8 -5
  54. {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/METADATA +6 -6
  55. {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/RECORD +57 -43
  56. {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/LICENSE +0 -0
  57. {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/WHEEL +0 -0
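
The hunk below comes from lamindb/curators/__init__.py (entry 23, +7 -3349): the curator implementations move out of the package __init__ into the new lamindb/curators/core.py (entry 25), while the legacy cat managers land in lamindb/curators/_legacy.py (entry 24) and two of them stay importable from lamindb.curators for backward compatibility. A minimal sketch of what that means for imports; only the two legacy re-exports are confirmed by the diff below, and the curators.core home of DataFrameCurator is an assumption based on the new "core" autosummary entry:

    # sketch against the 1.5.1 layout listed above (assumptions marked inline)
    from lamindb.curators import (  # re-exported from ._legacy per the diff below
        CellxGeneAnnDataCatManager,
        PertAnnDataCatManager,
    )
    from lamindb.curators.core import DataFrameCurator  # assumed new home of the curator classes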
@@ -5,3362 +5,20 @@
 
    DataFrameCurator
    AnnDataCurator
-   MuDataCurator
    SpatialDataCurator
+   MuDataCurator
 
-Helper classes.
+Modules.
 
 .. autosummary::
    :toctree: .
 
-   Curator
-   SlotsCurator
-   CatManager
-   CatLookup
-   DataFrameCatManager
-   AnnDataCatManager
-   MuDataCatManager
-   SpatialDataCatManager
-   TiledbsomaCatManager
+   core
 
 """
 
-from __future__ import annotations
-
-import copy
-import re
-from itertools import chain
-from typing import TYPE_CHECKING, Any, Callable, Literal
-
-import anndata as ad
-import lamindb_setup as ln_setup
-import pandas as pd
-import pandera
-import pyarrow as pa
-from lamin_utils import colors, logger
-from lamindb_setup.core import deprecated
-from lamindb_setup.core._docs import doc_args
-from lamindb_setup.core.upath import UPath
-
-from lamindb.core._compat import is_package_installed
-
-if TYPE_CHECKING:
-    from lamindb_setup.core.types import UPathStr
-    from mudata import MuData
-    from spatialdata import SpatialData
-
-    from lamindb.core.types import ScverseDataStructures
-    from lamindb.models import Record
-from lamindb.base.types import FieldAttr  # noqa
-from lamindb.core._settings import settings
-from lamindb.models import (
-    Artifact,
-    Feature,
-    Record,
-    Run,
-    Schema,
-    ULabel,
+from ._legacy import (  # backward compat
+    CellxGeneAnnDataCatManager,
+    PertAnnDataCatManager,
 )
-from lamindb.models.artifact import (
-    add_labels,
-    data_is_anndata,
-    data_is_mudata,
-    data_is_spatialdata,
-)
-from lamindb.models.feature import parse_dtype, parse_cat_dtype
-from lamindb.models._from_values import _format_values
-
-from ..errors import InvalidArgument, ValidationError
-from anndata import AnnData
-
-if TYPE_CHECKING:
-    from collections.abc import Iterable, MutableMapping
-    from typing import Any
-
-    from lamindb_setup.core.types import UPathStr
-
-    from lamindb.models.query_set import RecordList
-
-
-def strip_ansi_codes(text):
-    # This pattern matches ANSI escape sequences
-    ansi_pattern = re.compile(r"\x1b\[[0-9;]*m")
-    return ansi_pattern.sub("", text)
-
-
-class CatLookup:
-    """Lookup categories from the reference instance.
-
-    Args:
-        categoricals: A dictionary of categorical fields to lookup.
-        slots: A dictionary of slot fields to lookup.
-        public: Whether to lookup from the public instance. Defaults to False.
-
-    Example::
-
-        curator = ln.curators.DataFrameCurator(...)
-        curator.cat.lookup()["cell_type"].alveolar_type_1_fibroblast_cell
-
-    """
-
-    def __init__(
-        self,
-        categoricals: list[Feature] | dict[str, FieldAttr],
-        slots: dict[str, FieldAttr] = None,
-        public: bool = False,
-        sources: dict[str, Record] | None = None,
-    ) -> None:
-        slots = slots or {}
-        if isinstance(categoricals, list):
-            categoricals = {
-                feature.name: parse_dtype(feature.dtype)[0]["field"]
-                for feature in categoricals
-            }
-        self._categoricals = {**categoricals, **slots}
-        self._public = public
-        self._sources = sources
-
-    def __getattr__(self, name):
-        if name in self._categoricals:
-            registry = self._categoricals[name].field.model
-            if self._public and hasattr(registry, "public"):
-                return registry.public(source=self._sources.get(name)).lookup()
-            else:
-                return registry.lookup()
-        raise AttributeError(
-            f'"{self.__class__.__name__}" object has no attribute "{name}"'
-        )
-
-    def __getitem__(self, name):
-        if name in self._categoricals:
-            registry = self._categoricals[name].field.model
-            if self._public and hasattr(registry, "public"):
-                return registry.public(source=self._sources.get(name)).lookup()
-            else:
-                return registry.lookup()
-        raise AttributeError(
-            f'"{self.__class__.__name__}" object has no attribute "{name}"'
-        )
-
-    def __repr__(self) -> str:
-        if len(self._categoricals) > 0:
-            getattr_keys = "\n ".join(
-                [f".{key}" for key in self._categoricals if key.isidentifier()]
-            )
-            getitem_keys = "\n ".join(
-                [str([key]) for key in self._categoricals if not key.isidentifier()]
-            )
-            ref = "public" if self._public else "registries"
-            return (
-                f"Lookup objects from the {colors.italic(ref)}:\n "
-                f"{colors.green(getattr_keys)}\n "
-                f"{colors.green(getitem_keys)}\n"
-                'Example:\n → categories = curator.lookup()["cell_type"]\n'
-                " → categories.alveolar_type_1_fibroblast_cell\n\n"
-                "To look up public ontologies, use .lookup(public=True)"
-            )
-        else:  # pragma: no cover
-            return colors.warning("No fields are found!")
-
-
-
-CAT_MANAGER_DOCSTRING = """Manage categoricals by updating registries."""
-
-
-SLOTS_DOCSTRING = """Curator objects by slot.
-
-.. versionadded:: 1.1.1
-"""
-
-
-VALIDATE_DOCSTRING = """Validate dataset against Schema.
-
-Raises:
-    lamindb.errors.ValidationError: If validation fails.
-"""
-
-SAVE_ARTIFACT_DOCSTRING = """Save an annotated artifact.
-
-Args:
-    key: A path-like key to reference artifact in default storage, e.g., `"myfolder/myfile.fcs"`. Artifacts with the same key form a version family.
-    description: A description.
-    revises: Previous version of the artifact. Is an alternative way to passing `key` to trigger a new version.
-    run: The run that creates the artifact.
-
-Returns:
-    A saved artifact record.
-"""
-
-
-class Curator:
-    """Curator base class.
-
-    A `Curator` object makes it easy to validate, standardize & annotate datasets.
-
-    See:
-        - :class:`~lamindb.curators.DataFrameCurator`
-        - :class:`~lamindb.curators.AnnDataCurator`
-        - :class:`~lamindb.curators.MuDataCurator`
-        - :class:`~lamindb.curators.SpatialDataCurator`
-
-    .. versionadded:: 1.1.0
-    """
-
-    def __init__(self, dataset: Any, schema: Schema | None = None):
-        self._artifact: Artifact = None  # pass the dataset as an artifact
-        self._dataset: Any = dataset  # pass the dataset as a UPathStr or data object
-        if isinstance(self._dataset, Artifact):
-            self._artifact = self._dataset
-            if self._artifact.otype in {
-                "DataFrame",
-                "AnnData",
-                "MuData",
-                "SpatialData",
-            }:
-                self._dataset = self._dataset.load()
-        self._schema: Schema | None = schema
-        self._is_validated: bool = False
-        self._cat_manager: CatManager = None  # is None for CatManager curators
-
-    @doc_args(VALIDATE_DOCSTRING)
-    def validate(self) -> bool | str:
-        """{}"""  # noqa: D415
-        pass  # pragma: no cover
-
-    @doc_args(SAVE_ARTIFACT_DOCSTRING)
-    def save_artifact(
-        self,
-        *,
-        key: str | None = None,
-        description: str | None = None,
-        revises: Artifact | None = None,
-        run: Run | None = None,
-    ) -> Artifact:
-        """{}"""  # noqa: D415
-        # Note that this docstring has to be consistent with the Artifact()
-        # constructor signature
-        pass  # pragma: no cover
-
-
-# default implementation for MuDataCurator and SpatialDataCurator
-class SlotsCurator(Curator):
-    """Curator for a dataset with slots.
-
-    Args:
-        dataset: The dataset to validate & annotate.
-        schema: A `Schema` object that defines the validation constraints.
-
-    .. versionadded:: 1.3.0
-    """
-
-    def __init__(
-        self,
-        dataset: Any,
-        schema: Schema,
-    ) -> None:
-        super().__init__(dataset=dataset, schema=schema)
-        self._slots: dict[str, DataFrameCurator] = {}
-
-        # used in MuDataCurator and SpatialDataCurator
-        # in form of {table/modality_key: var_field}
-        self._var_fields: dict[str, FieldAttr] = {}
-        # in form of {table/modality_key: categoricals}
-        self._cat_columns: dict[str, dict[str, CatColumn]] = {}
-
-    @property
-    @doc_args(SLOTS_DOCSTRING)
-    def slots(self) -> dict[str, DataFrameCurator]:
-        """{}"""  # noqa: D415
-        return self._slots
-
-    @doc_args(VALIDATE_DOCSTRING)
-    def validate(self) -> None:
-        """{}"""  # noqa: D415
-        for slot, curator in self._slots.items():
-            logger.info(f"validating slot {slot} ...")
-            curator.validate()
-
-    @doc_args(SAVE_ARTIFACT_DOCSTRING)
-    def save_artifact(
-        self,
-        *,
-        key: str | None = None,
-        description: str | None = None,
-        revises: Artifact | None = None,
-        run: Run | None = None,
-    ) -> Artifact:
-        """{}"""  # noqa: D415
-        if not self._is_validated:
-            self.validate()
-        if self._artifact is None:
-            if data_is_mudata(self._dataset):
-                self._artifact = Artifact.from_mudata(
-                    self._dataset,
-                    key=key,
-                    description=description,
-                    revises=revises,
-                    run=run,
-                )
-            elif data_is_spatialdata(self._dataset):
-                self._artifact = Artifact.from_spatialdata(
-                    self._dataset,
-                    key=key,
-                    description=description,
-                    revises=revises,
-                    run=run,
-                )
-            self._artifact.schema = self._schema
-            self._artifact.save()
-        cat_columns = {}
-        for curator in self._slots.values():
-            for key, cat_column in curator._cat_manager._cat_columns.items():
-                cat_columns[key] = cat_column
-        return annotate_artifact(  # type: ignore
-            self._artifact,
-            index_field=self._var_fields,
-            schema=self._schema,
-            cat_columns=cat_columns,
-        )
-
-
-def check_dtype(expected_type) -> Callable:
-    """Creates a check function for Pandera that validates a column's dtype.
-
-    Args:
-        expected_type: String identifier for the expected type ('int', 'float', or 'num')
-
-    Returns:
-        A function that checks if a series has the expected dtype
-    """
-
-    def check_function(series):
-        if expected_type == "int":
-            is_valid = pd.api.types.is_integer_dtype(series.dtype)
-        elif expected_type == "float":
-            is_valid = pd.api.types.is_float_dtype(series.dtype)
-        elif expected_type == "num":
-            is_valid = pd.api.types.is_numeric_dtype(series.dtype)
-        return is_valid
-
-    return check_function
-
-
-class DataFrameCurator(Curator):
-    # the example in the docstring is tested in test_curators_quickstart_example
-    """Curator for `DataFrame`.
-
-    See also :class:`~lamindb.Curator` and :class:`~lamindb.Schema`.
-
-    .. versionadded:: 1.1.0
-
-    Args:
-        dataset: The DataFrame-like object to validate & annotate.
-        schema: A `Schema` object that defines the validation constraints.
-
-    Example::
-
-        import lamindb as ln
-        import bionty as bt
-
-        # define valid labels
-        perturbation = ln.ULabel(name="Perturbation", is_type=True).save()
-        ln.ULabel(name="DMSO", type=perturbation).save()
-        ln.ULabel(name="IFNG", type=perturbation).save()
-        bt.CellType.from_source(name="B cell").save()
-        bt.CellType.from_source(name="T cell").save()
-
-        # define schema
-        schema = ln.Schema(
-            name="small_dataset1_obs_level_metadata",
-            features=[
-                ln.Feature(name="perturbation", dtype="cat[ULabel[Perturbation]]").save(),
-                ln.Feature(name="sample_note", dtype=str).save(),
-                ln.Feature(name="cell_type_by_expert", dtype=bt.CellType).save(),
-                ln.Feature(name="cell_type_by_model", dtype=bt.CellType).save(),
-            ],
-        ).save()
-
-        # curate a DataFrame
-        df = datasets.small_dataset1(otype="DataFrame")
-        curator = ln.curators.DataFrameCurator(df, schema)
-        artifact = curator.save_artifact(key="example_datasets/dataset1.parquet")
-        assert artifact.schema == schema
-    """
-
-    def __init__(
-        self,
-        dataset: pd.DataFrame | Artifact,
-        schema: Schema,
-    ) -> None:
-        super().__init__(dataset=dataset, schema=schema)
-        categoricals = []
-        features = []
-        feature_ids: set[int] = set()
-        if schema.flexible and isinstance(self._dataset, pd.DataFrame):
-            features += Feature.filter(name__in=self._dataset.keys()).list()
-            feature_ids = {feature.id for feature in features}
-        if schema.n > 0:
-            schema_features = schema.features.all().list()
-            if feature_ids:
-                features.extend(
-                    feature
-                    for feature in schema_features
-                    if feature.id not in feature_ids
-                )
-            else:
-                features.extend(schema_features)
-        else:
-            assert schema.itype is not None  # noqa: S101
-        if features:
-            # populate features
-            pandera_columns = {}
-            if schema.minimal_set:
-                optional_feature_uids = set(schema.optionals.get_uids())
-            for feature in features:
-                if schema.minimal_set:
-                    required = feature.uid not in optional_feature_uids
-                else:
-                    required = False
-                if feature.dtype in {"int", "float", "num"}:
-                    dtype = (
-                        self._dataset[feature.name].dtype
-                        if feature.name in self._dataset.columns
-                        else None
-                    )
-                    pandera_columns[feature.name] = pandera.Column(
-                        dtype=None,
-                        checks=pandera.Check(
-                            check_dtype(feature.dtype),
-                            element_wise=False,
-                            error=f"Column '{feature.name}' failed dtype check for '{feature.dtype}': got {dtype}",
-                        ),
-                        nullable=feature.nullable,
-                        coerce=feature.coerce_dtype,
-                        required=required,
-                    )
-                else:
-                    pandera_dtype = (
-                        feature.dtype
-                        if not feature.dtype.startswith("cat")
-                        else "category"
-                    )
-                    pandera_columns[feature.name] = pandera.Column(
-                        pandera_dtype,
-                        nullable=feature.nullable,
-                        coerce=feature.coerce_dtype,
-                        required=required,
-                    )
-                if feature.dtype.startswith("cat"):
-                    # validate categoricals if the column is required or if the column is present
-                    if required or feature.name in self._dataset.columns:
-                        categoricals.append(feature)
-            self._pandera_schema = pandera.DataFrameSchema(
-                pandera_columns,
-                coerce=schema.coerce_dtype,
-                strict=schema.maximal_set,
-                ordered=schema.ordered_set,
-            )
-        self._cat_manager = DataFrameCatManager(
-            self._dataset,
-            columns=parse_cat_dtype(schema.itype, is_itype=True)["field"],
-            categoricals=categoricals,
-        )
-
-    @property
-    @doc_args(CAT_MANAGER_DOCSTRING)
-    def cat(self) -> CatManager:
-        """{}"""  # noqa: D415
-        return self._cat_manager
-
-    def standardize(self) -> None:
-        """Standardize the dataset.
-
-        - Adds missing columns for features
-        - Fills missing values for features with default values
-        """
-        for feature in self._schema.members:
-            if feature.name not in self._dataset.columns:
-                if feature.default_value is not None or feature.nullable:
-                    fill_value = (
-                        feature.default_value
-                        if feature.default_value is not None
-                        else pd.NA
-                    )
-                    if feature.dtype.startswith("cat"):
-                        self._dataset[feature.name] = pd.Categorical(
-                            [fill_value] * len(self._dataset)
-                        )
-                    else:
-                        self._dataset[feature.name] = fill_value
-                    logger.important(
-                        f"added column {feature.name} with fill value {fill_value}"
-                    )
-                else:
-                    raise ValidationError(
-                        f"Missing column {feature.name} cannot be added because is not nullable and has no default value"
-                    )
-            else:
-                if feature.default_value is not None:
-                    if isinstance(
-                        self._dataset[feature.name].dtype, pd.CategoricalDtype
-                    ):
-                        if (
-                            feature.default_value
-                            not in self._dataset[feature.name].cat.categories
-                        ):
-                            self._dataset[feature.name] = self._dataset[
-                                feature.name
-                            ].cat.add_categories(feature.default_value)
-                    self._dataset[feature.name] = self._dataset[feature.name].fillna(
-                        feature.default_value
-                    )
-
-    def _cat_manager_validate(self) -> None:
-        self._cat_manager.validate()
-        if self._cat_manager._is_validated:
-            self._is_validated = True
-        else:
-            self._is_validated = False
-            raise ValidationError(self._cat_manager._validate_category_error_messages)
-
-    @doc_args(VALIDATE_DOCSTRING)
-    def validate(self) -> None:
-        """{}"""  # noqa: D415
-        if self._schema.n > 0:
-            try:
-                # first validate through pandera
-                self._pandera_schema.validate(self._dataset)
-                # then validate lamindb categoricals
-                self._cat_manager_validate()
-            except pandera.errors.SchemaError as err:
-                self._is_validated = False
-                # .exconly() doesn't exist on SchemaError
-                raise ValidationError(str(err)) from err
-        else:
-            self._cat_manager_validate()
-
-    @doc_args(SAVE_ARTIFACT_DOCSTRING)
-    def save_artifact(
-        self,
-        *,
-        key: str | None = None,
-        description: str | None = None,
-        revises: Artifact | None = None,
-        run: Run | None = None,
-    ) -> Artifact:
-        """{}"""  # noqa: D415
-        if not self._is_validated:
-            self.validate()  # raises ValidationError if doesn't validate
-        result = parse_cat_dtype(self._schema.itype, is_itype=True)
-        if self._artifact is None:
-            self._artifact = Artifact.from_df(
-                self._dataset,
-                key=key,
-                description=description,
-                revises=revises,
-                run=run,
-            )
-            self._artifact.schema = self._schema
-            self._artifact.save()
-        return annotate_artifact(  # type: ignore
-            self._artifact,
-            index_field=result["field"],
-            schema=self._schema,
-            cat_columns=self._cat_manager._cat_columns,
-        )
-
-
-class AnnDataCurator(SlotsCurator):
-    # the example in the docstring is tested in test_curators_quickstart_example
-    """Curator for `AnnData`.
-
-    See also :class:`~lamindb.Curator` and :class:`~lamindb.Schema`.
-
-    .. versionadded:: 1.1.0
-
-    Args:
-        dataset: The AnnData-like object to validate & annotate.
-        schema: A `Schema` object that defines the validation constraints.
-
-    Example::
-
-        import lamindb as ln
-        import bionty as bt
-
-        # define valid labels
-        perturbation = ln.ULabel(name="Perturbation", is_type=True).save()
-        ln.ULabel(name="DMSO", type=perturbation).save()
-        ln.ULabel(name="IFNG", type=perturbation).save()
-        bt.CellType.from_source(name="B cell").save()
-        bt.CellType.from_source(name="T cell").save()
-
-        # define obs schema
-        obs_schema = ln.Schema(
-            name="small_dataset1_obs_level_metadata",
-            features=[
-                ln.Feature(name="perturbation", dtype="cat[ULabel[Perturbation]]").save(),
-                ln.Feature(name="sample_note", dtype=str).save(),
-                ln.Feature(name="cell_type_by_expert", dtype=bt.CellType).save(),
-                ln.Feature(name="cell_type_by_model", dtype=bt.CellType).save(),
-            ],
-        ).save()
-
-        # define var schema
-        var_schema = ln.Schema(
-            name="scRNA_seq_var_schema",
-            itype=bt.Gene.ensembl_gene_id,
-            dtype=int,
-        ).save()
-
-        # define composite schema
-        anndata_schema = ln.Schema(
-            name="small_dataset1_anndata_schema",
-            otype="AnnData",
-            components={"obs": obs_schema, "var": var_schema},
-        ).save()
-
-        # curate an AnnData
-        adata = ln.core.datasets.small_dataset1(otype="AnnData")
-        curator = ln.curators.AnnDataCurator(adata, anndata_schema)
-        artifact = curator.save_artifact(key="example_datasets/dataset1.h5ad")
-        assert artifact.schema == anndata_schema
-    """
-
-    def __init__(
-        self,
-        dataset: AnnData | Artifact,
-        schema: Schema,
-    ) -> None:
-        super().__init__(dataset=dataset, schema=schema)
-        if not data_is_anndata(self._dataset):
-            raise InvalidArgument("dataset must be AnnData-like.")
-        if schema.otype != "AnnData":
-            raise InvalidArgument("Schema otype must be 'AnnData'.")
-        # TODO: also support slots other than obs and var
-        self._slots = {
-            slot: DataFrameCurator(
-                (
-                    getattr(self._dataset, slot).T
-                    if slot == "var"
-                    else getattr(self._dataset, slot)
-                ),
-                slot_schema,
-            )
-            for slot, slot_schema in schema.slots.items()
-            if slot in {"obs", "var", "uns"}
-        }
-        # TODO: better way to handle this!
-        if "var" in self._slots:
-            self._slots["var"]._cat_manager._cat_columns["var_index"] = self._slots[
-                "var"
-            ]._cat_manager._cat_columns.pop("columns")
-            self._slots["var"]._cat_manager._cat_columns["var_index"]._key = "var_index"
-
-    @doc_args(SAVE_ARTIFACT_DOCSTRING)
-    def save_artifact(
-        self,
-        *,
-        key: str | None = None,
-        description: str | None = None,
-        revises: Artifact | None = None,
-        run: Run | None = None,
-    ) -> Artifact:
-        """{}"""  # noqa: D415
-        if not self._is_validated:
-            self.validate()
-        if self._artifact is None:
-            self._artifact = Artifact.from_anndata(
-                self._dataset,
-                key=key,
-                description=description,
-                revises=revises,
-                run=run,
-            )
-            self._artifact.schema = self._schema
-            self._artifact.save()
-        return annotate_artifact(  # type: ignore
-            self._artifact,
-            cat_columns=(
-                self.slots["obs"]._cat_manager._cat_columns
-                if "obs" in self.slots
-                else {}
-            ),
-            index_field=(
-                parse_cat_dtype(self.slots["var"]._schema.itype, is_itype=True)["field"]
-                if "var" in self._slots
-                else None
-            ),
-            schema=self._schema,
-        )
-
-
-def _assign_var_fields_categoricals_multimodal(
-    modality: str | None,
-    slot_type: str,
-    slot: str,
-    slot_schema: Schema,
-    var_fields: dict[str, FieldAttr],
-    cat_columns: dict[str, dict[str, CatColumn]],
-    slots: dict[str, DataFrameCurator],
-) -> None:
-    """Assigns var_fields and categoricals for multimodal data curators."""
-    if modality is not None:
-        # Makes sure that all tables are present
-        var_fields[modality] = None
-        cat_columns[modality] = {}
-
-    if slot_type == "var":
-        var_field = parse_cat_dtype(slot_schema.itype, is_itype=True)["field"]
-        if modality is None:
-            # This should rarely/never be used since tables should have different var fields
-            var_fields[slot] = var_field  # pragma: no cover
-        else:
-            # Note that this is NOT nested since the nested key is always "var"
-            var_fields[modality] = var_field
-    else:
-        obs_fields = slots[slot]._cat_manager._cat_columns
-        if modality is None:
-            cat_columns[slot] = obs_fields
-        else:
-            # Note that this is NOT nested since the nested key is always "obs"
-            cat_columns[modality] = obs_fields
-
-
-class MuDataCurator(SlotsCurator):
-    # the example in the docstring is tested in test_curators_quickstart_example
-    """Curator for `MuData`.
-
-    See also :class:`~lamindb.Curator` and :class:`~lamindb.Schema`.
-
-    .. versionadded:: 1.3.0
-
-    Args:
-        dataset: The MuData-like object to validate & annotate.
-        schema: A `Schema` object that defines the validation constraints.
-
-    Example::
-
-        import lamindb as ln
-        import bionty as bt
-
-        # define the global obs schema
-        obs_schema = ln.Schema(
-            name="mudata_papalexi21_subset_obs_schema",
-            features=[
-                ln.Feature(name="perturbation", dtype="cat[ULabel[Perturbation]]").save(),
-                ln.Feature(name="replicate", dtype="cat[ULabel[Replicate]]").save(),
-            ],
-        ).save()
-
-        # define the ['rna'].obs schema
-        obs_schema_rna = ln.Schema(
-            name="mudata_papalexi21_subset_rna_obs_schema",
-            features=[
-                ln.Feature(name="nCount_RNA", dtype=int).save(),
-                ln.Feature(name="nFeature_RNA", dtype=int).save(),
-                ln.Feature(name="percent.mito", dtype=float).save(),
-            ],
-            coerce_dtype=True,
-        ).save()
-
-        # define the ['hto'].obs schema
-        obs_schema_hto = ln.Schema(
-            name="mudata_papalexi21_subset_hto_obs_schema",
-            features=[
-                ln.Feature(name="nCount_HTO", dtype=int).save(),
-                ln.Feature(name="nFeature_HTO", dtype=int).save(),
-                ln.Feature(name="technique", dtype=bt.ExperimentalFactor).save(),
-            ],
-            coerce_dtype=True,
-        ).save()
-
-        # define ['rna'].var schema
-        var_schema_rna = ln.Schema(
-            name="mudata_papalexi21_subset_rna_var_schema",
-            itype=bt.Gene.symbol,
-            dtype=float,
-        ).save()
-
-        # define composite schema
-        mudata_schema = ln.Schema(
-            name="mudata_papalexi21_subset_mudata_schema",
-            otype="MuData",
-            components={
-                "obs": obs_schema,
-                "rna:obs": obs_schema_rna,
-                "hto:obs": obs_schema_hto,
-                "rna:var": var_schema_rna,
-            },
-        ).save()
-
-        # curate a MuData
-        mdata = ln.core.datasets.mudata_papalexi21_subset()
-        bt.settings.organism = "human"  # set the organism
-        curator = ln.curators.MuDataCurator(mdata, mudata_schema)
-        artifact = curator.save_artifact(key="example_datasets/mudata_papalexi21_subset.h5mu")
-        assert artifact.schema == mudata_schema
-    """
-
-    def __init__(
-        self,
-        dataset: MuData | Artifact,
-        schema: Schema,
-    ) -> None:
-        super().__init__(dataset=dataset, schema=schema)
-        if not data_is_mudata(self._dataset):
-            raise InvalidArgument("dataset must be MuData-like.")
-        if schema.otype != "MuData":
-            raise InvalidArgument("Schema otype must be 'MuData'.")
-
-        for slot, slot_schema in schema.slots.items():
-            # Assign to _slots
-            if ":" in slot:
-                modality, modality_slot = slot.split(":")
-                schema_dataset = self._dataset.__getitem__(modality)
-            else:
-                modality, modality_slot = None, slot
-                schema_dataset = self._dataset
-            self._slots[slot] = DataFrameCurator(
-                (
-                    getattr(schema_dataset, modality_slot).T
-                    if modality_slot == "var"
-                    else getattr(schema_dataset, modality_slot)
-                ),
-                slot_schema,
-            )
-            _assign_var_fields_categoricals_multimodal(
-                modality=modality,
-                slot_type=modality_slot,
-                slot=slot,
-                slot_schema=slot_schema,
-                var_fields=self._var_fields,
-                cat_columns=self._cat_columns,
-                slots=self._slots,
-            )
-
-        # for consistency with BaseCatManager
-        self._columns_field = self._var_fields
-
-
-class SpatialDataCurator(SlotsCurator):
-    # the example in the docstring is tested in test_curators_quickstart_example
-    """Curator for `SpatialData`.
-
-    See also :class:`~lamindb.Curator` and :class:`~lamindb.Schema`.
-
-    .. versionadded:: 1.3.0
-
-    Args:
-        dataset: The SpatialData-like object to validate & annotate.
-        schema: A `Schema` object that defines the validation constraints.
-
-    Example::
-
-        import lamindb as ln
-        import bionty as bt
-
-        # define sample schema
-        sample_schema = ln.Schema(
-            name="blobs_sample_level_metadata",
-            features=[
-                ln.Feature(name="assay", dtype=bt.ExperimentalFactor).save(),
-                ln.Feature(name="disease", dtype=bt.Disease).save(),
-                ln.Feature(name="development_stage", dtype=bt.DevelopmentalStage).save(),
-            ],
-            coerce_dtype=True
-        ).save()
-
-        # define table obs schema
-        blobs_obs_schema = ln.Schema(
-            name="blobs_obs_level_metadata",
-            features=[
-                ln.Feature(name="sample_region", dtype="str").save(),
-            ],
-            coerce_dtype=True
-        ).save()
-
-        # define table var schema
-        blobs_var_schema = ln.Schema(
-            name="blobs_var_schema",
-            itype=bt.Gene.ensembl_gene_id,
-            dtype=int
-        ).save()
-
-        # define composite schema
-        spatialdata_schema = ln.Schema(
-            name="blobs_spatialdata_schema",
-            otype="SpatialData",
-            components={
-                "sample": sample_schema,
-                "table:obs": blobs_obs_schema,
-                "table:var": blobs_var_schema,
-        }).save()
-
-        # curate a SpatialData
-        spatialdata = ln.core.datasets.spatialdata_blobs()
-        curator = ln.curators.SpatialDataCurator(spatialdata, spatialdata_schema)
-        try:
-            curator.validate()
-        except ln.errors.ValidationError as error:
-            print(error)
-
-        # validate again (must pass now) and save artifact
-        artifact = curator.save_artifact(key="example_datasets/spatialdata1.zarr")
-        assert artifact.schema == spatialdata_schema
-    """
-
-    def __init__(
-        self,
-        dataset: SpatialData | Artifact,
-        schema: Schema,
-        *,
-        sample_metadata_key: str | None = "sample",
-    ) -> None:
-        super().__init__(dataset=dataset, schema=schema)
-        if not data_is_spatialdata(self._dataset):
-            raise InvalidArgument("dataset must be SpatialData-like.")
-        if schema.otype != "SpatialData":
-            raise InvalidArgument("Schema otype must be 'SpatialData'.")
-
-        for slot, slot_schema in schema.slots.items():
-            # Assign to _slots
-            if ":" in slot:
-                table_key, table_slot = slot.split(":")
-                schema_dataset = self._dataset.tables.__getitem__(table_key)
-            # sample metadata (does not have a `:` separator)
-            else:
-                table_key = None
-                table_slot = slot
-                schema_dataset = self._dataset.get_attrs(
-                    key=sample_metadata_key, return_as="df", flatten=True
-                )
-
-            self._slots[slot] = DataFrameCurator(
-                (
-                    getattr(schema_dataset, table_slot).T
-                    if table_slot == "var"
-                    else (
-                        getattr(schema_dataset, table_slot)
-                        if table_slot != sample_metadata_key
-                        else schema_dataset
-                    )  # just take the schema_dataset if it's the sample metadata key
-                ),
-                slot_schema,
-            )
-
-            _assign_var_fields_categoricals_multimodal(
-                modality=table_key,
-                slot_type=table_slot,
-                slot=slot,
-                slot_schema=slot_schema,
-                var_fields=self._var_fields,
-                cat_columns=self._cat_columns,
-                slots=self._slots,
-            )
-
-        # for consistency with BaseCatManager
-        self._columns_field = self._var_fields
-
-
964
- class CatColumn:
965
- """Categorical column for `DataFrame`.
966
-
967
- Args:
968
- values_getter: A callable or iterable that returns the values to validate.
969
- field: The field to validate against.
970
- key: The name of the column to validate. Only used for logging.
971
- values_setter: A callable that sets the values.
972
- source: The source to validate against.
973
- """
974
-
975
- def __init__(
976
- self,
977
- values_getter: Callable | Iterable[str],
978
- field: FieldAttr,
979
- key: str,
980
- values_setter: Callable | None = None,
981
- source: Record | None = None,
982
- feature: Feature | None = None,
983
- ) -> None:
984
- self._values_getter = values_getter
985
- self._values_setter = values_setter
986
- self._field = field
987
- self._key = key
988
- self._source = source
989
- self._organism = None
990
- self._validated: None | list[str] = None
991
- self._non_validated: None | list[str] = None
992
- self._synonyms: None | dict[str, str] = None
993
- self.feature = feature
994
- self.labels = None
995
- if hasattr(field.field.model, "_name_field"):
996
- label_ref_is_name = field.field.name == field.field.model._name_field
997
- else:
998
- label_ref_is_name = field.field.name == "name"
999
- self.label_ref_is_name = label_ref_is_name
1000
-
1001
- @property
1002
- def values(self):
1003
- """Get the current values using the getter function."""
1004
- if callable(self._values_getter):
1005
- return self._values_getter()
1006
- return self._values_getter
1007
-
1008
- @values.setter
1009
- def values(self, new_values):
1010
- """Set new values using the setter function if available."""
1011
- if callable(self._values_setter):
1012
- self._values_setter(new_values)
1013
- else:
1014
- # If values_getter is not callable, it's a direct reference we can update
1015
- self._values_getter = new_values
1016
-
1017
- @property
1018
- def is_validated(self) -> bool:
1019
- """Return whether the column is validated."""
1020
- return len(self._non_validated) == 0
1021
-
1022
- def _replace_synonyms(self) -> list[str]:
1023
- """Replace synonyms in the column with standardized values."""
1024
- syn_mapper = self._synonyms
1025
- # replace the values in df
1026
- std_values = self.values.map(
1027
- lambda unstd_val: syn_mapper.get(unstd_val, unstd_val)
1028
- )
1029
- # remove the standardized values from self.non_validated
1030
- non_validated = [i for i in self._non_validated if i not in syn_mapper]
1031
- if len(non_validated) == 0:
1032
- self._non_validated = []
1033
- else:
1034
- self._non_validated = non_validated # type: ignore
1035
- # logging
1036
- n = len(syn_mapper)
1037
- if n > 0:
1038
- syn_mapper_print = _format_values(
1039
- [f'"{k}" → "{v}"' for k, v in syn_mapper.items()], sep=""
1040
- )
1041
- s = "s" if n > 1 else ""
1042
- logger.success(
1043
- f'standardized {n} synonym{s} in "{self._key}": {colors.green(syn_mapper_print)}'
1044
- )
1045
- return std_values
1046
-
1047
- def _add_validated(self) -> tuple[list, list]:
1048
- """Save features or labels records in the default instance."""
1049
- from lamindb.models.save import save as ln_save
1050
-
1051
- registry = self._field.field.model
1052
- field_name = self._field.field.name
1053
- model_field = registry.__get_name_with_module__()
1054
- filter_kwargs = get_current_filter_kwargs(
1055
- registry, {"organism": self._organism, "source": self._source}
1056
- )
1057
- values = [i for i in self.values if isinstance(i, str) and i]
1058
- if not values:
1059
- return [], []
1060
-
1061
- # inspect the default instance and save validated records from public
1062
- existing_and_public_records = registry.from_values(
1063
- list(values), field=self._field, **filter_kwargs, mute=True
1064
- )
1065
- existing_and_public_labels = [
1066
- getattr(r, field_name) for r in existing_and_public_records
1067
- ]
1068
- # public records that are not already in the database
1069
- public_records = [r for r in existing_and_public_records if r._state.adding]
1070
- # here we check to only save the public records if they are from the specified source
1071
- # we check the uid because r.source and source can be from different instances
1072
- if self._source:
1073
- public_records = [
1074
- r for r in public_records if r.source.uid == self._source.uid
1075
- ]
1076
- if len(public_records) > 0:
1077
- logger.info(f"saving validated records of '{self._key}'")
1078
- ln_save(public_records)
1079
- labels_saved_public = [getattr(r, field_name) for r in public_records]
1080
- # log the saved public labels
1081
- # the term "transferred" stresses that this is always in the context of transferring
1082
- # labels from a public ontology or a different instance to the present instance
1083
- if len(labels_saved_public) > 0:
1084
- s = "s" if len(labels_saved_public) > 1 else ""
1085
- logger.success(
1086
- f'added {len(labels_saved_public)} record{s} {colors.green("from_public")} with {model_field} for "{self._key}": {_format_values(labels_saved_public)}'
1087
- )
1088
- self.labels = existing_and_public_records
1089
-
1090
- # non-validated records from the default instance
1091
- non_validated_labels = [
1092
- i for i in values if i not in existing_and_public_labels
1093
- ]
1094
-
1095
- # validated, non-validated
1096
- return existing_and_public_labels, non_validated_labels
1097
-
1098
- def _add_new(
1099
- self,
1100
- values: list[str],
1101
- df: pd.DataFrame | None = None, # remove when all users use schema
1102
- dtype: str | None = None,
1103
- **create_kwargs,
1104
- ) -> None:
1105
- """Add new labels to the registry."""
1106
- from lamindb.models.save import save as ln_save
1107
-
1108
- registry = self._field.field.model
1109
- field_name = self._field.field.name
1110
- non_validated_records: RecordList[Any] = [] # type: ignore
1111
- if df is not None and registry == Feature:
1112
- nonval_columns = Feature.inspect(df.columns, mute=True).non_validated
1113
- non_validated_records = Feature.from_df(df.loc[:, nonval_columns])
1114
- else:
1115
- if (
1116
- self._organism
1117
- and hasattr(registry, "organism")
1118
- and registry._meta.get_field("organism").is_relation
1119
- ):
1120
- # make sure organism record is saved to the current instance
1121
- create_kwargs["organism"] = _save_organism(name=self._organism)
1122
-
1123
- for value in values:
1124
- init_kwargs = {field_name: value}
1125
- if registry == Feature:
1126
- init_kwargs["dtype"] = "cat" if dtype is None else dtype
1127
- non_validated_records.append(registry(**init_kwargs, **create_kwargs))
1128
- if len(non_validated_records) > 0:
1129
- ln_save(non_validated_records)
1130
- model_field = colors.italic(registry.__get_name_with_module__())
1131
- s = "s" if len(values) > 1 else ""
1132
- logger.success(
1133
- f'added {len(values)} record{s} with {model_field} for "{self._key}": {_format_values(values)}'
1134
- )
1135
-
1136
- def _validate(
1137
- self,
1138
- values: list[str],
1139
- curator: CatManager | None = None, # TODO: not yet used
1140
- ) -> tuple[list[str], dict]:
1141
- """Validate ontology terms using LaminDB registries."""
1142
- registry = self._field.field.model
1143
- field_name = self._field.field.name
1144
- model_field = f"{registry.__name__}.{field_name}"
1145
-
1146
- def _log_mapping_info():
1147
- logger.indent = ""
1148
- logger.info(f'mapping "{self._key}" on {colors.italic(model_field)}')
1149
- logger.indent = " "
1150
-
1151
- kwargs_current = get_current_filter_kwargs(
1152
- registry, {"organism": self._organism, "source": self._source}
1153
- )
1154
-
1155
- # inspect values from the default instance, excluding public
1156
- inspect_result = registry.inspect(
1157
- values, field=self._field, mute=True, from_source=False, **kwargs_current
1158
- )
1159
- non_validated = inspect_result.non_validated
1160
- syn_mapper = inspect_result.synonyms_mapper
1161
-
1162
- # inspect the non-validated values from public (BioRecord only)
1163
- values_validated = []
1164
- if hasattr(registry, "public"):
1165
- public_records = registry.from_values(
1166
- non_validated,
1167
- field=self._field,
1168
- mute=True,
1169
- **kwargs_current,
1170
- )
1171
- values_validated += [getattr(r, field_name) for r in public_records]
1172
-
1173
- # logging messages
1174
- non_validated_hint_print = f'.add_new_from("{self._key}")'
1175
- non_validated = [i for i in non_validated if i not in values_validated]
1176
- n_non_validated = len(non_validated)
1177
- if n_non_validated == 0:
1178
- logger.indent = ""
1179
- logger.success(
1180
- f'"{self._key}" is validated against {colors.italic(model_field)}'
1181
- )
1182
- return [], {}
1183
- else:
1184
- are = "is" if n_non_validated == 1 else "are"
1185
- s = "" if n_non_validated == 1 else "s"
1186
- print_values = _format_values(non_validated)
1187
- warning_message = f"{colors.red(f'{n_non_validated} term{s}')} {are} not validated: {colors.red(print_values)}\n"
1188
- if syn_mapper:
1189
- s = "" if len(syn_mapper) == 1 else "s"
1190
- syn_mapper_print = _format_values(
1191
- [f'"{k}" → "{v}"' for k, v in syn_mapper.items()], sep=""
1192
- )
1193
- hint_msg = f'.standardize("{self._key}")'
1194
- warning_message += f" {colors.yellow(f'{len(syn_mapper)} synonym{s}')} found: {colors.yellow(syn_mapper_print)}\n → curate synonyms via {colors.cyan(hint_msg)}"
1195
- if n_non_validated > len(syn_mapper):
1196
- if syn_mapper:
1197
- warning_message += "\n for remaining terms:\n"
1198
- warning_message += f" → fix typos, remove non-existent values, or save terms via {colors.cyan(non_validated_hint_print)}"
1199
-
1200
- if logger.indent == "":
1201
- _log_mapping_info()
1202
- logger.warning(warning_message)
1203
- if curator is not None:
1204
- curator._validate_category_error_messages = strip_ansi_codes(
1205
- warning_message
1206
- )
1207
- logger.indent = ""
1208
- return non_validated, syn_mapper
1209
-
1210
- def validate(self) -> None:
1211
- """Validate the column."""
1212
- # add source-validated values to the registry
1213
- self._validated, self._non_validated = self._add_validated()
1214
- self._non_validated, self._synonyms = self._validate(values=self._non_validated)
1215
- # always register new Features if they are columns
1216
- if self._key == "columns" and self._field == Feature.name:
1217
- self.add_new()
1218
-
1219
- def standardize(self) -> None:
1220
- """Standardize the column."""
1221
- registry = self._field.field.model
1222
- if not hasattr(registry, "standardize"):
1223
- return self.values
1224
- if self._synonyms is None:
1225
- self.validate()
1226
- # get standardized values
1227
- std_values = self._replace_synonyms()
1228
- # update non_validated values
1229
- self._non_validated = [
1230
- i for i in self._non_validated if i not in self._synonyms.keys()
1231
- ]
1232
- # remove synonyms since they are now standardized
1233
- self._synonyms = {}
1234
- # update the values with the standardized values
1235
- self.values = std_values
1236
-
1237
- def add_new(self, **create_kwargs) -> None:
1238
- """Add new values to the registry."""
1239
- if self._non_validated is None:
1240
- self.validate()
1241
- if len(self._synonyms) > 0:
1242
- # raise error because .standardize modifies the input dataset
1243
- raise ValidationError(
1244
- "Please run `.standardize()` before adding new values."
1245
- )
1246
- self._add_new(
1247
- values=self._non_validated,
1248
- **create_kwargs,
1249
- )
1250
- # remove the non_validated values since they are now registered
1251
- self._non_validated = []
1252
-
1253
-
-class CatManager:
-    """Manage categoricals by updating registries.
-
-    This class is accessible from within a `DataFrameCurator` via the `.cat` attribute.
-
-    If you find non-validated values, you have several options:
-
-    - new values found in the data can be registered via `DataFrameCurator.cat.add_new_from()` :meth:`~lamindb.curators.DataFrameCatManager.add_new_from`
-    - non-validated values can be accessed via `DataFrameCurator.cat.add_new_from()` :meth:`~lamindb.curators.DataFrameCatManager.non_validated` and addressed manually
-    """
-
-    def __init__(self, *, dataset, categoricals, sources, columns_field=None):
-        # the below is shared with Curator
-        self._artifact: Artifact = None  # pass the dataset as an artifact
-        self._dataset: Any = dataset  # pass the dataset as a UPathStr or data object
-        if isinstance(self._dataset, Artifact):
-            self._artifact = self._dataset
-            if self._artifact.otype in {"DataFrame", "AnnData"}:
-                self._dataset = self._dataset.load(
-                    is_run_input=False  # we already track this in the Curator constructor
-                )
-        self._is_validated: bool = False
-        # shared until here
-        self._categoricals = categoricals or {}
-        self._non_validated = None
-        self._sources = sources or {}
-        self._columns_field = columns_field
-        self._validate_category_error_messages: str = ""
-        self._cat_columns: dict[str, CatColumn] = {}
-
-    @property
-    def non_validated(self) -> dict[str, list[str]]:
-        """Return the non-validated features and labels."""
-        if self._non_validated is None:
-            raise ValidationError("Please run validate() first!")
-        return {
-            key: cat_column._non_validated
-            for key, cat_column in self._cat_columns.items()
-            if cat_column._non_validated and key != "columns"
-        }
-
-    @property
-    def categoricals(self) -> dict:
-        """Return the columns fields to validate against."""
-        return self._categoricals
-
-    def validate(self) -> bool:
-        """Validate dataset.
-
-        This method also registers the validated records in the current instance.
-
-        Returns:
-            The boolean `True` if the dataset is validated. Otherwise, a string with the error message.
-        """
-        pass  # pragma: no cover
-
-    def standardize(self, key: str) -> None:
-        """Replace synonyms with standardized values.
-
-        Inplace modification of the dataset.
-
-        Args:
-            key: The name of the column to standardize.
-
-        Returns:
-            None
-        """
-        pass  # pragma: no cover
-
-    @doc_args(SAVE_ARTIFACT_DOCSTRING)
-    def save_artifact(
-        self,
-        *,
-        key: str | None = None,
-        description: str | None = None,
-        revises: Artifact | None = None,
-        run: Run | None = None,
-    ) -> Artifact:
-        """{}"""  # noqa: D415
-        # Make sure all labels are saved in the current instance
-        if not self._is_validated:
-            self.validate()  # returns True or False
-            if not self._is_validated:  # need to raise error manually
-                raise ValidationError("Dataset does not validate. Please curate.")
-
-        if self._artifact is None:
-            if isinstance(self._dataset, pd.DataFrame):
-                artifact = Artifact.from_df(
-                    self._dataset,
-                    key=key,
-                    description=description,
-                    revises=revises,
-                    run=run,
-                )
-            elif isinstance(self._dataset, AnnData):
-                artifact = Artifact.from_anndata(
-                    self._dataset,
-                    key=key,
-                    description=description,
-                    revises=revises,
-                    run=run,
-                )
-            elif data_is_mudata(self._dataset):
-                artifact = Artifact.from_mudata(
-                    self._dataset,
-                    key=key,
-                    description=description,
-                    revises=revises,
-                    run=run,
-                )
-            elif data_is_spatialdata(self._dataset):
-                artifact = Artifact.from_spatialdata(
-                    self._dataset,
-                    key=key,
-                    description=description,
-                    revises=revises,
-                    run=run,
-                )
-            else:
-                raise InvalidArgument(  # pragma: no cover
-                    "data must be one of pd.Dataframe, AnnData, MuData, SpatialData."
-                )
-            self._artifact = artifact.save()
-        annotate_artifact(  # type: ignore
-            self._artifact,
-            index_field=self._columns_field,
-            cat_columns=self._cat_columns,
-        )
-        return self._artifact
-
-
-class DataFrameCatManager(CatManager):
-    """Categorical manager for `DataFrame`."""
-
-    def __init__(
-        self,
-        df: pd.DataFrame | Artifact,
-        columns: FieldAttr = Feature.name,
-        categoricals: list[Feature] | dict[str, FieldAttr] | None = None,
-        sources: dict[str, Record] | None = None,
-    ) -> None:
-        self._non_validated = None
-        super().__init__(
-            dataset=df,
-            columns_field=columns,
-            categoricals=categoricals,
-            sources=sources,
-        )
-        if columns == Feature.name:
-            if isinstance(self._categoricals, list):
-                values = [feature.name for feature in self._categoricals]
-            else:
-                values = list(self._categoricals.keys())
-            self._cat_columns["columns"] = CatColumn(
-                values_getter=values,
-                field=self._columns_field,
-                key="columns" if isinstance(self._dataset, pd.DataFrame) else "keys",
-                source=self._sources.get("columns"),
-            )
-            self._cat_columns["columns"].validate()
-        else:
-            # NOTE: for var_index right now
-            self._cat_columns["columns"] = CatColumn(
-                values_getter=lambda: self._dataset.columns,  # lambda ensures the inplace update
-                values_setter=lambda new_values: setattr(
-                    self._dataset, "columns", pd.Index(new_values)
-                ),
-                field=self._columns_field,
-                key="columns",
-                source=self._sources.get("columns"),
-            )
-        if isinstance(self._categoricals, list):
-            for feature in self._categoricals:
-                result = parse_dtype(feature.dtype)[
-                    0
-                ]  # TODO: support composite dtypes for categoricals
-                key = feature.name
-                field = result["field"]
-                self._cat_columns[key] = CatColumn(
-                    values_getter=lambda k=key: self._dataset[
-                        k
-                    ],  # Capture key as default argument
-                    values_setter=lambda new_values, k=key: self._dataset.__setitem__(
-                        k, new_values
-                    ),
-                    field=field,
-                    key=key,
-                    source=self._sources.get(key),
-                    feature=feature,
-                )
-        else:
-            # below is for backward compat of ln.Curator.from_df()
-            for key, field in self._categoricals.items():
-                self._cat_columns[key] = CatColumn(
-                    values_getter=lambda k=key: self._dataset[
-                        k
-                    ],  # Capture key as default argument
-                    values_setter=lambda new_values, k=key: self._dataset.__setitem__(
-                        k, new_values
-                    ),
-                    field=field,
-                    key=key,
-                    source=self._sources.get(key),
-                    feature=Feature.get(name=key),
-                )
-
-    def lookup(self, public: bool = False) -> CatLookup:
-        """Lookup categories.
-
-        Args:
-            public: If "public", the lookup is performed on the public reference.
-        """
-        return CatLookup(
-            categoricals=self._categoricals,
-            slots={"columns": self._columns_field},
-            public=public,
-            sources=self._sources,
-        )
-
-    def validate(self) -> bool:
-        """Validate variables and categorical observations."""
-        self._validate_category_error_messages = ""  # reset the error messages
-
-        validated = True
-        for _, cat_column in self._cat_columns.items():
-            cat_column.validate()
-            validated &= cat_column.is_validated
-        self._is_validated = validated
-        self._non_validated = {}  # so it's no longer None
-
-        return self._is_validated
-
-    def standardize(self, key: str) -> None:
-        """Replace synonyms with standardized values.
-
-        Modifies the input dataset inplace.
-
-        Args:
-            key: The key referencing the column in the DataFrame to standardize.
-        """
-        if self._artifact is not None:
-            raise RuntimeError("can't mutate the dataset when an artifact is passed!")
-
-        if key == "all":
-            logger.warning(
-                "'all' is deprecated, please pass a single key from `.non_validated.keys()` instead!"
-            )
-            for k in self.non_validated.keys():
-                self._cat_columns[k].standardize()
-        else:
-            self._cat_columns[key].standardize()
-
-    def add_new_from(self, key: str, **kwargs):
-        """Add validated & new categories.
-
-        Args:
-            key: The key referencing the slot in the DataFrame from which to draw terms.
-            **kwargs: Additional keyword arguments to pass to create new records
-        """
-        if len(kwargs) > 0 and key == "all":
-            raise ValueError("Cannot pass additional arguments to 'all' key!")
-        if key == "all":
-            logger.warning(
-                "'all' is deprecated, please pass a single key from `.non_validated.keys()` instead!"
-            )
-            for k in self.non_validated.keys():
-                self._cat_columns[k].add_new(**kwargs)
-        else:
-            self._cat_columns[key].add_new(**kwargs)
-
-    @deprecated(
-        new_name="Run.filter(transform=context.run.transform, output_artifacts=None)"
-    )
-    def clean_up_failed_runs(self):
-        """Clean up previous failed runs that don't save any outputs."""
-        from lamindb.core._context import context
-
-        if context.run is not None:
-            Run.filter(transform=context.run.transform, output_artifacts=None).exclude(
-                uid=context.run.uid
-            ).delete()
-
-
1537
- class AnnDataCatManager(CatManager):
1538
- """Categorical manager for `AnnData`."""
1539
-
1540
- def __init__(
1541
- self,
1542
- data: ad.AnnData | Artifact,
1543
- var_index: FieldAttr | None = None,
1544
- categoricals: dict[str, FieldAttr] | None = None,
1545
- obs_columns: FieldAttr = Feature.name,
1546
- sources: dict[str, Record] | None = None,
1547
- ) -> None:
1548
- if isinstance(var_index, str):
1549
- raise TypeError(
1550
- "var_index parameter has to be a field, e.g. Gene.ensembl_gene_id"
1551
- )
1552
-
1553
- if not data_is_anndata(data):
1554
- raise TypeError("data has to be an AnnData object")
1555
-
1556
- if "symbol" in str(var_index):
1557
- logger.warning(
1558
- "indexing datasets with gene symbols can be problematic: https://docs.lamin.ai/faq/symbol-mapping"
1559
- )
1560
-
1561
- self._obs_fields = categoricals or {}
1562
- self._var_field = var_index
1563
- self._sources = sources or {}
1564
- super().__init__(
1565
- dataset=data,
1566
- categoricals=categoricals,
1567
- sources=self._sources,
1568
- columns_field=var_index,
1569
- )
1570
- self._adata = self._dataset
1571
- self._obs_df_curator = DataFrameCatManager(
1572
- df=self._adata.obs,
1573
- categoricals=self.categoricals,
1574
- columns=obs_columns,
1575
- sources=self._sources,
1576
- )
1577
- self._cat_columns = self._obs_df_curator._cat_columns.copy()
1578
- if var_index is not None:
1579
- self._cat_columns["var_index"] = CatColumn(
1580
- values_getter=lambda: self._adata.var.index,
1581
- values_setter=lambda new_values: setattr(
1582
- self._adata.var, "index", pd.Index(new_values)
1583
- ),
1584
- field=self._var_field,
1585
- key="var_index",
1586
- source=self._sources.get("var_index"),
1587
- )
1588
-
1589
- @property
1590
- def var_index(self) -> FieldAttr:
1591
- """Return the registry field to validate variables index against."""
1592
- return self._var_field
1593
-
1594
- @property
1595
- def categoricals(self) -> dict:
1596
- """Return the obs fields to validate against."""
1597
- return self._obs_fields
1598
-
1599
- def lookup(self, public: bool = False) -> CatLookup:
1600
- """Lookup categories.
1601
-
1602
- Args:
1603
- public: Whether to perform the lookup on the public ontology reference.
1604
- """
1605
- return CatLookup(
1606
- categoricals=self._obs_fields,
1607
- slots={"columns": self._columns_field, "var_index": self._var_field},
1608
- public=public,
1609
- sources=self._sources,
1610
- )
1611
-
1612
- def add_new_from(self, key: str, **kwargs):
1613
- """Add validated & new categories.
1614
-
1615
- Args:
1616
- key: The key referencing the slot in the DataFrame from which to draw terms.
1617
- **kwargs: Additional keyword arguments to pass to create new records
1618
- """
1619
- if key == "all":
1620
- logger.warning(
1621
- "'all' is deprecated, please pass a single key from `.non_validated.keys()` instead!"
1622
- )
1623
- for k in self.non_validated.keys():
1624
- self._cat_columns[k].add_new(**kwargs)
1625
- else:
1626
- self._cat_columns[key].add_new(**kwargs)
1627
-
1628
- @deprecated(new_name="add_new_from('var_index')")
1629
- def add_new_from_var_index(self, **kwargs):
1630
- """Update variable records.
1631
-
1632
- Args:
1633
- **kwargs: Additional keyword arguments to pass to create new records.
1634
- """
1635
- self.add_new_from(key="var_index", **kwargs)
1636
-
1637
- def validate(self) -> bool:
1638
- """Validate categories.
1639
-
1640
- This method also registers the validated records in the current instance.
1641
-
1642
- Returns:
1643
- Whether the AnnData object is validated.
1644
- """
1645
- self._validate_category_error_messages = "" # reset the error messages
1646
-
1647
- validated = True
1648
- for _, cat_column in self._cat_columns.items():
1649
- cat_column.validate()
1650
- validated &= cat_column.is_validated
1651
-
1652
- self._non_validated = {} # so it's no longer None
1653
- self._is_validated = validated
1654
- return self._is_validated
1655
-
1656
- def standardize(self, key: str):
1657
- """Replace synonyms with standardized values.
1658
-
1659
- Args:
1660
- key: The key referencing the slot in `adata.obs` from which to draw terms. Same as the key in `categoricals`.
1661
-
1662
- - If "var_index", standardize the var.index.
1663
- - If "all", standardize all obs columns and var.index.
1664
-
1665
- Inplace modification of the dataset.
1666
- """
1667
- if self._artifact is not None:
1668
- raise RuntimeError("can't mutate the dataset when an artifact is passed!")
1669
- if key == "all":
1670
- logger.warning(
1671
- "'all' is deprecated, please pass a single key from `.non_validated.keys()` instead!"
1672
- )
1673
- for k in self.non_validated.keys():
1674
- self._cat_columns[k].standardize()
1675
- else:
1676
- self._cat_columns[key].standardize()
1677
-
1678
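
For orientation, a minimal end-to-end sketch of the `AnnDataCatManager` flow removed here; the tiny `AnnData` and the `bt.Gene.ensembl_gene_id` / `bt.CellType.name` fields are illustrative assumptions, not part of the diff:

    import anndata as ad
    import bionty as bt
    import pandas as pd

    adata = ad.AnnData(
        obs=pd.DataFrame({"cell_type": ["T cell"]}, index=["cell1"]),
        var=pd.DataFrame(index=["ENSG00000139618"]),
    )
    curator = AnnDataCatManager(
        adata,
        var_index=bt.Gene.ensembl_gene_id,             # validates var.index
        categoricals={"cell_type": bt.CellType.name},  # validates obs columns
    )
    if not curator.validate():
        curator.standardize("cell_type")   # map synonyms in place
        curator.add_new_from("cell_type")  # save remaining new terms
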
-
1679
- @deprecated(new_name="MuDataCurator")
1680
- class MuDataCatManager(CatManager):
1681
- """Categorical manager for `MuData`."""
1682
-
1683
- def __init__(
1684
- self,
1685
- mdata: MuData | Artifact,
1686
- var_index: dict[str, FieldAttr] | None = None,
1687
- categoricals: dict[str, FieldAttr] | None = None,
1688
- sources: dict[str, Record] | None = None,
1689
- ) -> None:
1690
- super().__init__(
1691
- dataset=mdata,
1692
- categoricals={},
1693
- sources=sources,
1694
- )
1695
- self._columns_field = (
1696
- var_index or {}
1697
- ) # this is for consistency with BaseCatManager
1698
- self._var_fields = var_index or {}
1699
- self._verify_modality(self._var_fields.keys())
1700
- self._obs_fields = self._parse_categoricals(categoricals or {})
1701
- self._modalities = set(self._var_fields.keys()) | set(self._obs_fields.keys())
1702
- self._obs_df_curator = None
1703
- if "obs" in self._modalities:
1704
- self._obs_df_curator = DataFrameCatManager(
1705
- df=self._dataset.obs,
1706
- columns=Feature.name,
1707
- categoricals=self._obs_fields.get("obs", {}),
1708
- sources=self._sources.get("obs"),
1709
- )
1710
- self._mod_adata_curators = {
1711
- modality: AnnDataCatManager(
1712
- data=self._dataset[modality],
1713
- var_index=var_index.get(modality),
1714
- categoricals=self._obs_fields.get(modality),
1715
- sources=self._sources.get(modality),
1716
- )
1717
- for modality in self._modalities
1718
- if modality != "obs"
1719
- }
1720
- self._non_validated = None
1721
-
1722
- @property
1723
- def var_index(self) -> FieldAttr:
1724
- """Return the registry field to validate variables index against."""
1725
- return self._var_fields
1726
-
1727
- @property
1728
- def categoricals(self) -> dict:
1729
- """Return the obs fields to validate against."""
1730
- return self._obs_fields
1731
-
1732
- @property
1733
- def non_validated(self) -> dict[str, dict[str, list[str]]]: # type: ignore
1734
- """Return the non-validated features and labels."""
1735
- if self._non_validated is None:
1736
- raise ValidationError("Please run validate() first!")
1737
- non_validated = {}
1738
- if (
1739
- self._obs_df_curator is not None
1740
- and len(self._obs_df_curator.non_validated) > 0
1741
- ):
1742
- non_validated["obs"] = self._obs_df_curator.non_validated
1743
- for modality, adata_curator in self._mod_adata_curators.items():
1744
- if len(adata_curator.non_validated) > 0:
1745
- non_validated[modality] = adata_curator.non_validated
1746
- self._non_validated = non_validated
1747
- return self._non_validated
1748
-
1749
- def _verify_modality(self, modalities: Iterable[str]):
1750
- """Verify the modality exists."""
1751
- for modality in modalities:
1752
- if modality not in self._dataset.mod.keys():
1753
- raise ValidationError(f"modality '{modality}' does not exist!")
1754
-
1755
- def _parse_categoricals(self, categoricals: dict[str, FieldAttr]) -> dict:
1756
- """Parse the categorical fields."""
1757
- prefixes = {f"{k}:" for k in self._dataset.mod.keys()}
1758
- obs_fields: dict[str, dict[str, FieldAttr]] = {}
1759
- for k, v in categoricals.items():
1760
- if k not in self._dataset.obs.columns:
1761
- raise ValidationError(f"column '{k}' does not exist in mdata.obs!")
1762
- if any(k.startswith(prefix) for prefix in prefixes):
1763
- modality, col = k.split(":", 1)  # keep any further colons in the column name
1764
- if modality not in obs_fields.keys():
1765
- obs_fields[modality] = {}
1766
- obs_fields[modality][col] = v
1767
- else:
1768
- if "obs" not in obs_fields.keys():
1769
- obs_fields["obs"] = {}
1770
- obs_fields["obs"][k] = v
1771
- return obs_fields
1772
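
`_parse_categoricals` above routes keys of the form `"<modality>:<column>"` to the matching modality and everything else to the top-level `obs`. An illustration with hypothetical modality and column names:

    categoricals = {
        "rna:cell_type": bt.CellType.name,  # column of mdata["rna"].obs
        "donor": ULabel.name,               # column of the top-level mdata.obs
    }
    # parsed into:
    # {"rna": {"cell_type": bt.CellType.name}, "obs": {"donor": ULabel.name}}
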
-
1773
- def lookup(self, public: bool = False) -> CatLookup:
1774
- """Lookup categories.
1775
-
1776
- Args:
1777
- public: Perform lookup on public source ontologies.
1778
- """
1779
- obs_fields = {}
1780
- for mod, fields in self._obs_fields.items():
1781
- for k, v in fields.items():
1782
- if k == "obs":
1783
- obs_fields[k] = v
1784
- else:
1785
- obs_fields[f"{mod}:{k}"] = v
1786
- return CatLookup(
1787
- categoricals=obs_fields,
1788
- slots={
1789
- **{f"{k}_var_index": v for k, v in self._var_fields.items()},
1790
- },
1791
- public=public,
1792
- sources=self._sources,
1793
- )
1794
-
1795
- @deprecated(new_name="add_new_from('var_index')")
1796
- def add_new_from_var_index(self, modality: str, **kwargs):
1797
- """Update variable records.
1798
-
1799
- Args:
1800
- modality: The modality name.
1801
- **kwargs: Additional keyword arguments to pass to create new records.
1802
- """
1803
- self._mod_adata_curators[modality].add_new_from(key="var_index", **kwargs)
1804
-
1805
- def add_new_from(
1806
- self,
1807
- key: str,
1808
- modality: str | None = None,
1809
- **kwargs,
1810
- ):
1811
- """Add validated & new categories.
1812
-
1813
- Args:
1814
- key: The key referencing the slot in the DataFrame.
1815
- modality: The modality name.
1816
- **kwargs: Additional keyword arguments to pass to create new records.
1817
- """
1818
- modality = modality or "obs"
1819
- if modality in self._mod_adata_curators:
1820
- adata_curator = self._mod_adata_curators[modality]
1821
- adata_curator.add_new_from(key=key, **kwargs)
1822
- if modality == "obs":
1823
- self._obs_df_curator.add_new_from(key=key, **kwargs)
1824
- if key == "var_index":
1825
- self._mod_adata_curators[modality].add_new_from(key=key, **kwargs)
1826
-
1827
- def validate(self) -> bool:
1828
- """Validate categories."""
1829
- obs_validated = True
1830
- if "obs" in self._modalities:
1831
- logger.info('validating categoricals in "obs"...')
1832
- obs_validated &= self._obs_df_curator.validate()
1833
-
1834
- mods_validated = True
1835
- for modality, adata_curator in self._mod_adata_curators.items():
1836
- logger.info(f'validating categoricals in modality "{modality}"...')
1837
- mods_validated &= adata_curator.validate()
1838
-
1839
- self._non_validated = {} # so it's no longer None
1840
- self._is_validated = obs_validated & mods_validated
1841
- return self._is_validated
1842
-
1843
- def standardize(self, key: str, modality: str | None = None):
1844
- """Replace synonyms with standardized values.
1845
-
1846
- Args:
1847
- key: The key referencing the slot in the `MuData`.
1848
- modality: The modality name.
1849
-
1850
- Inplace modification of the dataset.
1851
- """
1852
- if self._artifact is not None:
1853
- raise RuntimeError("can't mutate the dataset when an artifact is passed!")
1854
- modality = modality or "obs"
1855
- if modality in self._mod_adata_curators:
1856
- adata_curator = self._mod_adata_curators[modality]
1857
- adata_curator.standardize(key=key)
1858
- if modality == "obs":
1859
- self._obs_df_curator.standardize(key=key)
1860
-
1861
-
1862
- def _maybe_curation_keys_not_present(nonval_keys: list[str], name: str):
1863
- if (n := len(nonval_keys)) > 0:
1864
- s = "s" if n > 1 else ""
1865
- are = "are" if n > 1 else "is"
1866
- raise ValidationError(
1867
- f"key{s} passed to {name} {are} not present: {colors.yellow(_format_values(nonval_keys))}"
1868
- )
1869
-
1870
-
1871
- @deprecated(new_name="SpatialDataCurator")
1872
- class SpatialDataCatManager(CatManager):
1873
- """Categorical manager for `SpatialData`."""
1874
-
1875
- def __init__(
1876
- self,
1877
- sdata: Any,
1878
- var_index: dict[str, FieldAttr],
1879
- categoricals: dict[str, dict[str, FieldAttr]] | None = None,
1880
- sources: dict[str, dict[str, Record]] | None = None,
1881
- *,
1882
- sample_metadata_key: str | None = "sample",
1883
- ) -> None:
1884
- super().__init__(
1885
- dataset=sdata,
1886
- categoricals={},
1887
- sources=sources,
1888
- )
1889
- if isinstance(sdata, Artifact):
1890
- self._sdata = sdata.load()
1891
- else:
1892
- self._sdata = self._dataset
1893
- self._sample_metadata_key = sample_metadata_key
1894
- self._write_path = None
1895
- self._var_fields = var_index
1896
- self._verify_accessor_exists(self._var_fields.keys())
1897
- self._categoricals = categoricals
1898
- self._table_keys = set(self._var_fields.keys()) | set(
1899
- self._categoricals.keys() - {self._sample_metadata_key}
1900
- )
1901
- self._sample_df_curator = None
1902
- if self._sample_metadata_key is not None:
1903
- self._sample_metadata = self._sdata.get_attrs(
1904
- key=self._sample_metadata_key, return_as="df", flatten=True
1905
- )
1906
- self._is_validated = False
1907
-
1908
- # Check validity of keys in categoricals
1909
- nonval_keys = []
1910
- for accessor, accessor_categoricals in self._categoricals.items():
1911
- if (
1912
- accessor == self._sample_metadata_key
1913
- and self._sample_metadata is not None
1914
- ):
1915
- for key in accessor_categoricals.keys():
1916
- if key not in self._sample_metadata.columns:
1917
- nonval_keys.append(key)
1918
- else:
1919
- for key in accessor_categoricals.keys():
1920
- if key not in self._sdata[accessor].obs.columns:
1921
- nonval_keys.append(key)
1922
-
1923
- _maybe_curation_keys_not_present(nonval_keys, "categoricals")
1924
-
1925
- # check validity of keys in sources
1926
- nonval_keys = []
1927
- for accessor, accessor_sources in self._sources.items():
1928
- if (
1929
- accessor == self._sample_metadata_key
1930
- and self._sample_metadata is not None
1931
- ):
1932
- columns = self._sample_metadata.columns
1933
- elif accessor != self._sample_metadata_key:
1934
- columns = self._sdata[accessor].obs.columns
1935
- else:
1936
- continue
1937
- for key in accessor_sources:
1938
- if key not in columns:
1939
- nonval_keys.append(key)
1940
- _maybe_curation_keys_not_present(nonval_keys, "sources")
1941
-
1942
- # Set up sample level metadata and table Curator objects
1943
- if (
1944
- self._sample_metadata_key is not None
1945
- and self._sample_metadata_key in self._categoricals
1946
- ):
1947
- self._sample_df_curator = DataFrameCatManager(
1948
- df=self._sample_metadata,
1949
- columns=Feature.name,
1950
- categoricals=self._categoricals.get(self._sample_metadata_key, {}),
1951
- sources=self._sources.get(self._sample_metadata_key),
1952
- )
1953
- self._table_adata_curators = {
1954
- table: AnnDataCatManager(
1955
- data=self._sdata[table],
1956
- var_index=var_index.get(table),
1957
- categoricals=self._categoricals.get(table),
1958
- sources=self._sources.get(table),
1959
- )
1960
- for table in self._table_keys
1961
- }
1962
-
1963
- self._non_validated = None
1964
-
1965
- @property
1966
- def var_index(self) -> FieldAttr:
1967
- """Return the registry fields to validate variables indices against."""
1968
- return self._var_fields
1969
-
1970
- @property
1971
- def categoricals(self) -> dict[str, dict[str, FieldAttr]]:
1972
- """Return the categorical keys and fields to validate against."""
1973
- return self._categoricals
1974
-
1975
- @property
1976
- def non_validated(self) -> dict[str, dict[str, list[str]]]: # type: ignore
1977
- """Return the non-validated features and labels."""
1978
- if self._non_validated is None:
1979
- raise ValidationError("Please run validate() first!")
1980
- non_curated = {}
1981
- if len(self._sample_df_curator.non_validated) > 0:
1982
- non_curated[self._sample_metadata_key] = (
1983
- self._sample_df_curator.non_validated
1984
- )
1985
- for table, adata_curator in self._table_adata_curators.items():
1986
- if len(adata_curator.non_validated) > 0:
1987
- non_curated[table] = adata_curator.non_validated
1988
- return non_curated
1989
-
1990
- def _verify_accessor_exists(self, accessors: Iterable[str]) -> None:
1991
- """Verify that the accessors exist (either a valid table or in attrs)."""
1992
- for acc in accessors:
1993
- is_present = False
1994
- try:
1995
- self._sdata.get_attrs(key=acc)
1996
- is_present = True
1997
- except KeyError:
1998
- if acc in self._sdata.tables.keys():
1999
- is_present = True
2000
- if not is_present:
2001
- raise ValidationError(f"Accessor '{acc}' does not exist!")
2002
-
2003
- def lookup(self, public: bool = False) -> CatLookup:
2004
- """Look up categories.
2005
-
2006
- Args:
2007
- public: Whether the lookup is performed on the public reference.
2008
- """
2009
- cat_values_dict = list(self.categoricals.values())[0]
2010
- return CatLookup(
2011
- categoricals=cat_values_dict,
2012
- slots={"accessors": cat_values_dict.keys()},
2013
- public=public,
2014
- sources=self._sources,
2015
- )
2016
-
2017
- @deprecated(new_name="add_new_from('var_index')")
2018
- def add_new_from_var_index(self, table: str, **kwargs) -> None:
2019
- """Save new values from ``.var.index`` of table.
2020
-
2021
- Args:
2022
- table: The table key.
2023
- **kwargs: Additional keyword arguments to pass to create new records.
2024
- """
2025
- if table in self.non_validated.keys():
2026
- self._table_adata_curators[table].add_new_from(key="var_index", **kwargs)
2027
-
2028
- def add_new_from(
2029
- self,
2030
- key: str,
2031
- accessor: str | None = None,
2032
- **kwargs,
2033
- ) -> None:
2034
- """Save new values of categorical from sample level metadata or table.
2035
-
2036
- Args:
2037
- key: The key referencing the slot in the DataFrame.
2038
- accessor: The accessor key such as 'sample' or 'table x'.
2039
- **kwargs: Additional keyword arguments to pass to create new records.
2040
- """
2041
- if accessor in self.non_validated.keys():
2042
- if accessor in self._table_adata_curators:
2043
- adata_curator = self._table_adata_curators[accessor]
2044
- adata_curator.add_new_from(key=key, **kwargs)
2045
- if accessor == self._sample_metadata_key:
2046
- self._sample_df_curator.add_new_from(key=key, **kwargs)
2047
-
2048
- if key == "var_index":
2049
- self._table_adata_curators[accessor].add_new_from(key=key, **kwargs)
2050
-
2051
- def standardize(self, key: str, accessor: str | None = None) -> None:
2052
- """Replace synonyms with canonical values.
2053
-
2054
- Modifies the dataset inplace.
2055
-
2056
- Args:
2057
- key: The key referencing the slot in the table or sample metadata.
2058
- accessor: The accessor key such as 'sample_key' or 'table_key'.
2059
- """
2060
- if len(self.non_validated) == 0:
2061
- logger.warning("values are already standardized")
2062
- return
2063
- if self._artifact is not None:
2064
- raise RuntimeError("can't mutate the dataset when an artifact is passed!")
2065
-
2066
- if accessor == self._sample_metadata_key:
2067
- if key not in self._sample_metadata.columns:
2068
- raise ValueError(f"key '{key}' not present in '{accessor}'!")
2069
- else:
2070
- if (
2071
- key == "var_index" and self._sdata.tables[accessor].var.index is None
2072
- ) or (
2073
- key != "var_index"
2074
- and key not in self._sdata.tables[accessor].obs.columns
2075
- ):
2076
- raise ValueError(f"key '{key}' not present in '{accessor}'!")
2077
-
2078
- if accessor in self._table_adata_curators.keys():
2079
- adata_curator = self._table_adata_curators[accessor]
2080
- adata_curator.standardize(key)
2081
- if accessor == self._sample_metadata_key:
2082
- self._sample_df_curator.standardize(key)
2083
-
2084
- def validate(self) -> bool:
2085
- """Validate variables and categorical observations.
2086
-
2087
- This method also registers the validated records in the current instance:
2088
- - from public sources
2089
-
2090
- Returns:
2091
- Whether the SpatialData object is validated.
2092
- """
2093
- # add all validated records to the current instance
2094
- sample_validated = True
2095
- if self._sample_df_curator:
2096
- logger.info(f"validating categoricals of '{self._sample_metadata_key}' ...")
2097
- sample_validated &= self._sample_df_curator.validate()
2098
-
2099
- mods_validated = True
2100
- for table, adata_curator in self._table_adata_curators.items():
2101
- logger.info(f"validating categoricals of table '{table}' ...")
2102
- mods_validated &= adata_curator.validate()
2103
-
2104
- self._non_validated = {} # so it's no longer None
2105
- self._is_validated = sample_validated & mods_validated
2106
- return self._is_validated
2107
-
2108
- def save_artifact(
2109
- self,
2110
- *,
2111
- key: str | None = None,
2112
- description: str | None = None,
2113
- revises: Artifact | None = None,
2114
- run: Run | None = None,
2115
- ) -> Artifact:
2116
- """Save the validated SpatialData store and metadata.
2117
-
2118
- Args:
2119
- description: A description of the dataset.
2120
- key: A path-like key to reference artifact in default storage,
2121
- e.g., `"myartifact.zarr"`. Artifacts with the same key form a version family.
2122
- revises: Previous version of the artifact. Triggers a revision.
2123
- run: The run that creates the artifact.
2124
-
2125
- Returns:
2126
- A saved artifact record.
2127
- """
2128
- if not self._is_validated:
2129
- self.validate()
2130
- if not self._is_validated:
2131
- raise ValidationError("Dataset does not validate. Please curate.")
2132
-
2133
- self._artifact = Artifact.from_spatialdata(
2134
- self._dataset, key=key, description=description, revises=revises, run=run
2135
- ).save()
2136
- return annotate_artifact(
2137
- self._artifact,
2138
- index_field=self.var_index,
2139
- sample_metadata_key=self._sample_metadata_key,
2140
- )
2141
-
2142
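
Putting the deleted `SpatialDataCatManager` pieces together, a hedged usage sketch (the table key `"table"`, the `"sample"` attrs key, and all fields are illustrative):

    curator = SpatialDataCatManager(
        sdata,
        var_index={"table": bt.Gene.ensembl_gene_id},
        categoricals={
            "sample": {"organism": bt.Organism.name},  # sample-level metadata in attrs
            "table": {"cell_type": bt.CellType.name},  # per-table obs columns
        },
        sample_metadata_key="sample",
    )
    if curator.validate():
        artifact = curator.save_artifact(key="my_dataset.zarr")
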
-
2143
- class TiledbsomaCatManager(CatManager):
2144
- """Categorical manager for `tiledbsoma.Experiment`."""
2145
-
2146
- def __init__(
2147
- self,
2148
- experiment_uri: UPathStr | Artifact,
2149
- var_index: dict[str, tuple[str, FieldAttr]],
2150
- categoricals: dict[str, FieldAttr] | None = None,
2151
- obs_columns: FieldAttr = Feature.name,
2152
- sources: dict[str, Record] | None = None,
2153
- ):
2154
- self._obs_fields = categoricals or {}
2155
- self._var_fields = var_index
2156
- self._columns_field = obs_columns
2157
- if isinstance(experiment_uri, Artifact):
2158
- self._dataset = experiment_uri.path
2159
- self._artifact = experiment_uri
2160
- else:
2161
- self._dataset = UPath(experiment_uri)
2162
- self._artifact = None
2163
- self._sources = sources or {}
2164
-
2165
- self._is_validated: bool | None = False
2166
- self._non_validated_values: dict[str, list] | None = None
2167
- self._validated_values: dict[str, list] = {}
2168
- # filled by _check_save_keys
2169
- self._n_obs: int | None = None
2170
- self._valid_obs_keys: list[str] | None = None
2171
- self._obs_pa_schema: pa.lib.Schema | None = (
2172
- None # this is needed to create the obs feature set
2173
- )
2174
- self._valid_var_keys: list[str] | None = None
2175
- self._var_fields_flat: dict[str, FieldAttr] | None = None
2176
- self._check_save_keys()
2177
-
2178
- # check that the provided keys in var_index and categoricals are available in the store
2179
- # and save features
2180
- def _check_save_keys(self):
2181
- from lamindb.core.storage._tiledbsoma import _open_tiledbsoma
2182
-
2183
- with _open_tiledbsoma(self._dataset, mode="r") as experiment:
2184
- experiment_obs = experiment.obs
2185
- self._n_obs = len(experiment_obs)
2186
- self._obs_pa_schema = experiment_obs.schema
2187
- valid_obs_keys = [
2188
- k for k in self._obs_pa_schema.names if k != "soma_joinid"
2189
- ]
2190
- self._valid_obs_keys = valid_obs_keys
2191
-
2192
- valid_var_keys = []
2193
- ms_list = []
2194
- for ms in experiment.ms.keys():
2195
- ms_list.append(ms)
2196
- var_ms = experiment.ms[ms].var
2197
- valid_var_keys += [
2198
- f"{ms}__{k}" for k in var_ms.keys() if k != "soma_joinid"
2199
- ]
2200
- self._valid_var_keys = valid_var_keys
2201
-
2202
- # check validity of keys in categoricals
2203
- nonval_keys = []
2204
- for obs_key in self._obs_fields.keys():
2205
- if obs_key not in valid_obs_keys:
2206
- nonval_keys.append(obs_key)
2207
- _maybe_curation_keys_not_present(nonval_keys, "categoricals")
2208
-
2209
- # check validity of keys in var_index
2210
- self._var_fields_flat = {}
2211
- nonval_keys = []
2212
- for ms_key in self._var_fields.keys():
2213
- var_key, var_field = self._var_fields[ms_key]
2214
- var_key_flat = f"{ms_key}__{var_key}"
2215
- if var_key_flat not in valid_var_keys:
2216
- nonval_keys.append(f"({ms_key}, {var_key})")
2217
- else:
2218
- self._var_fields_flat[var_key_flat] = var_field
2219
- _maybe_curation_keys_not_present(nonval_keys, "var_index")
2220
-
2221
- # check validity of keys in sources
2222
- valid_arg_keys = valid_obs_keys + valid_var_keys + ["columns"]
2223
- nonval_keys = []
2224
- for arg_key in self._sources.keys():
2225
- if arg_key not in valid_arg_keys:
2226
- nonval_keys.append(arg_key)
2227
- _maybe_curation_keys_not_present(nonval_keys, "sources")
2228
-
2229
- # register obs columns' names
2230
- register_columns = list(self._obs_fields.keys())
2231
- # register categorical keys as features
2232
- cat_column = CatColumn(
2233
- values_getter=register_columns,
2234
- field=self._columns_field,
2235
- key="columns",
2236
- source=self._sources.get("columns"),
2237
- )
2238
- cat_column.add_new()
2239
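
`_check_save_keys` flattens each `(measurement, var column)` pair from `var_index` into a single `"<measurement>__<column>"` key; this flat form is also what `add_new_from` and `standardize` expect later. For example, with a hypothetical measurement `"RNA"`:

    var_index = {"RNA": ("var_id", bt.Gene.ensembl_gene_id)}
    # flattened internally into:
    # {"RNA__var_id": bt.Gene.ensembl_gene_id}
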
-
2240
- def validate(self):
2241
- """Validate categories."""
2242
- from lamindb.core.storage._tiledbsoma import _open_tiledbsoma
2243
-
2244
- validated = True
2245
- self._non_validated_values = {}
2246
- with _open_tiledbsoma(self._dataset, mode="r") as experiment:
2247
- for ms, (key, field) in self._var_fields.items():
2248
- var_ms = experiment.ms[ms].var
2249
- var_ms_key = f"{ms}__{key}"
2250
- # it was already validated and cached
2251
- if var_ms_key in self._validated_values:
2252
- continue
2253
- var_ms_values = (
2254
- var_ms.read(column_names=[key]).concat()[key].to_pylist()
2255
- )
2256
- cat_column = CatColumn(
2257
- values_getter=var_ms_values,
2258
- field=field,
2259
- key=var_ms_key,
2260
- source=self._sources.get(var_ms_key),
2261
- )
2262
- cat_column.validate()
2263
- non_val = cat_column._non_validated
2264
- if len(non_val) > 0:
2265
- validated = False
2266
- self._non_validated_values[var_ms_key] = non_val
2267
- else:
2268
- self._validated_values[var_ms_key] = var_ms_values
2269
-
2270
- obs = experiment.obs
2271
- for key, field in self._obs_fields.items():
2272
- # already validated and cached
2273
- if key in self._validated_values:
2274
- continue
2275
- values = pa.compute.unique(
2276
- obs.read(column_names=[key]).concat()[key]
2277
- ).to_pylist()
2278
- cat_column = CatColumn(
2279
- values_getter=values,
2280
- field=field,
2281
- key=key,
2282
- source=self._sources.get(key),
2283
- )
2284
- cat_column.validate()
2285
- non_val = cat_column._non_validated
2286
- if len(non_val) > 0:
2287
- validated = False
2288
- self._non_validated_values[key] = non_val
2289
- else:
2290
- self._validated_values[key] = values
2291
- self._is_validated = validated
2292
- return self._is_validated
2293
-
2294
- def _non_validated_values_field(self, key: str) -> tuple[list, FieldAttr]:
2295
- assert self._non_validated_values is not None # noqa: S101
2296
-
2297
- if key in self._valid_obs_keys:
2298
- field = self._obs_fields[key]
2299
- elif key in self._valid_var_keys:
2300
- ms = key.partition("__")[0]
2301
- field = self._var_fields[ms][1]
2302
- else:
2303
- raise KeyError(f"key {key} is invalid!")
2304
- values = self._non_validated_values.get(key, [])
2305
- return values, field
2306
-
2307
- def add_new_from(self, key: str, **kwargs) -> None:
2308
- """Add validated & new categories.
2309
-
2310
- Args:
2311
- key: The key referencing the slot in the `tiledbsoma` store.
2312
- It should be `'{measurement name}__{column name in .var}'` for columns in `.var`
2313
- or a column name in `.obs`.
2314
- """
2315
- if self._non_validated_values is None:
2316
- raise ValidationError("Run .validate() first.")
2317
- if key == "all":
2318
- keys = list(self._non_validated_values.keys())
2319
- else:
2320
- avail_keys = list(
2321
- chain(self._non_validated_values.keys(), self._validated_values.keys())
2322
- )
2323
- if key not in avail_keys:
2324
- raise KeyError(
2325
- f"{key!r} is not a valid key, available keys are: {_format_values(avail_keys + ['all'])}!"
2326
- )
2327
- keys = [key]
2328
- for k in keys:
2329
- values, field = self._non_validated_values_field(k)
2330
- if len(values) == 0:
2331
- continue
2332
- cat_column = CatColumn(
2333
- values_getter=values,
2334
- field=field,
2335
- key=k,
2336
- source=self._sources.get(k),
2337
- )
2338
- cat_column.add_new()
2339
- # update non-validated values list but keep the key there
2340
- # it will be removed by .validate()
2341
- if k in self._non_validated_values:
2342
- self._non_validated_values[k] = []
2343
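
Continuing the example above, after `validate()` the flattened keys (or plain `.obs` column names) are passed directly:

    curator.validate()
    curator.add_new_from("RNA__var_id")  # .var column of measurement "RNA"
    curator.add_new_from("cell_type")    # .obs column
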
-
2344
- @property
2345
- def non_validated(self) -> dict[str, list]:
2346
- """Return the non-validated features and labels."""
2347
- non_val = {k: v for k, v in self._non_validated_values.items() if v != []}
2348
- return non_val
2349
-
2350
- @property
2351
- def var_index(self) -> dict[str, FieldAttr]:
2352
- """Return the registry fields with flattened keys to validate variables indices against."""
2353
- return self._var_fields_flat
2354
-
2355
- @property
2356
- def categoricals(self) -> dict[str, FieldAttr]:
2357
- """Return the obs fields to validate against."""
2358
- return self._obs_fields
2359
-
2360
- def lookup(self, public: bool = False) -> CatLookup:
2361
- """Lookup categories.
2362
-
2363
- Args:
2364
- public: Whether to perform the lookup on the public ontology reference.
2365
- """
2366
- return CatLookup(
2367
- categoricals=self._obs_fields,
2368
- slots={"columns": self._columns_field, **self._var_fields_flat},
2369
- public=public,
2370
- sources=self._sources,
2371
- )
2372
-
2373
- def standardize(self, key: str):
2374
- """Replace synonyms with standardized values.
2375
-
2376
- Modifies the dataset inplace.
2377
-
2378
- Args:
2379
- key: The key referencing the slot in the `tiledbsoma` store.
2380
- It should be `'{measurement name}__{column name in .var}'` for columns in `.var`
2381
- or a column name in `.obs`.
2382
- """
2383
- if len(self.non_validated) == 0:
2384
- logger.warning("values are already standardized")
2385
- return
2386
- avail_keys = list(self._non_validated_values.keys())
2387
- if key == "all":
2388
- keys = avail_keys
2389
- else:
2390
- if key not in avail_keys:
2391
- raise KeyError(
2392
- f"{key!r} is not a valid key, available keys are: {_format_values(avail_keys + ['all'])}!"
2393
- )
2394
- keys = [key]
2395
-
2396
- for k in keys:
2397
- values, field = self._non_validated_values_field(k)
2398
- if len(values) == 0:
2399
- continue
2400
- if k in self._valid_var_keys:
2401
- ms, _, slot_key = k.partition("__")
2402
- slot = lambda experiment: experiment.ms[ms].var # noqa: B023
2403
- else:
2404
- slot = lambda experiment: experiment.obs
2405
- slot_key = k
2406
- cat_column = CatColumn(
2407
- values_getter=values,
2408
- field=field,
2409
- key=k,
2410
- source=self._sources.get(k),
2411
- )
2412
- cat_column.validate()
2413
- syn_mapper = cat_column._synonyms
2414
- if (n_syn_mapper := len(syn_mapper)) == 0:
2415
- continue
2416
-
2417
- from lamindb.core.storage._tiledbsoma import _open_tiledbsoma
2418
-
2419
- with _open_tiledbsoma(self._dataset, mode="r") as experiment:
2420
- value_filter = f"{slot_key} in {list(syn_mapper.keys())}"
2421
- table = slot(experiment).read(value_filter=value_filter).concat()
2422
-
2423
- if len(table) == 0:
2424
- continue
2425
-
2426
- df = table.to_pandas()
2427
- # map values
2428
- df[slot_key] = df[slot_key].map(
2429
- lambda val: syn_mapper.get(val, val) # noqa
2430
- )
2431
- # write the mapped values
2432
- with _open_tiledbsoma(self._dataset, mode="w") as experiment:
2433
- slot(experiment).write(pa.Table.from_pandas(df, schema=table.schema))
2434
- # update non_validated dict
2435
- non_val_k = [
2436
- nv for nv in self._non_validated_values[k] if nv not in syn_mapper
2437
- ]
2438
- self._non_validated_values[k] = non_val_k
2439
-
2440
- syn_mapper_print = _format_values(
2441
- [f'"{m_k}" → "{m_v}"' for m_k, m_v in syn_mapper.items()], sep=""
2442
- )
2443
- s = "s" if n_syn_mapper > 1 else ""
2444
- logger.success(
2445
- f'standardized {n_syn_mapper} synonym{s} in "{k}": {colors.green(syn_mapper_print)}'
2446
- )
2447
-
2448
- def save_artifact(
2449
- self,
2450
- *,
2451
- key: str | None = None,
2452
- description: str | None = None,
2453
- revises: Artifact | None = None,
2454
- run: Run | None = None,
2455
- ) -> Artifact:
2456
- """Save the validated `tiledbsoma` store and metadata.
2457
-
2458
- Args:
2459
- description: A description of the ``tiledbsoma`` store.
2460
- key: A path-like key to reference artifact in default storage,
2461
- e.g., `"myfolder/mystore.tiledbsoma"`. Artifacts with the same key form a version family.
2462
- revises: Previous version of the artifact. Triggers a revision.
2463
- run: The run that creates the artifact.
2464
-
2465
- Returns:
2466
- A saved artifact record.
2467
- """
2468
- if not self._is_validated:
2469
- self.validate()
2470
- if not self._is_validated:
2471
- raise ValidationError("Dataset does not validate. Please curate.")
2472
-
2473
- if self._artifact is None:
2474
- artifact = Artifact(
2475
- self._dataset,
2476
- description=description,
2477
- key=key,
2478
- revises=revises,
2479
- run=run,
2480
- )
2481
- artifact.n_observations = self._n_obs
2482
- artifact.otype = "tiledbsoma"
2483
- artifact.save()
2484
- else:
2485
- artifact = self._artifact
2486
-
2487
- feature_sets = {}
2488
- if len(self._obs_fields) > 0:
2489
- empty_dict = {field.name: [] for field in self._obs_pa_schema} # type: ignore
2490
- mock_df = pa.Table.from_pydict(
2491
- empty_dict, schema=self._obs_pa_schema
2492
- ).to_pandas()
2493
- # in parallel to https://github.com/laminlabs/lamindb/blob/2a1709990b5736b480c6de49c0ada47fafc8b18d/lamindb/core/_feature_manager.py#L549-L554
2494
- feature_sets["obs"] = Schema.from_df(
2495
- df=mock_df,
2496
- field=self._columns_field,
2497
- mute=True,
2498
- )
2499
- for ms in self._var_fields:
2500
- var_key, var_field = self._var_fields[ms]
2501
- feature_sets[f"{ms}__var"] = Schema.from_values(
2502
- values=self._validated_values[f"{ms}__{var_key}"],
2503
- field=var_field,
2504
- raise_validation_error=False,
2505
- )
2506
- artifact._staged_feature_sets = feature_sets
2507
-
2508
- feature_ref_is_name = _ref_is_name(self._columns_field)
2509
- features = Feature.lookup().dict()
2510
- for key, field in self._obs_fields.items():
2511
- feature = features.get(key)
2512
- registry = field.field.model
2513
- labels = registry.from_values(
2514
- values=self._validated_values[key],
2515
- field=field,
2516
- )
2517
- if len(labels) == 0:
2518
- continue
2519
- if hasattr(registry, "_name_field"):
2520
- label_ref_is_name = field.field.name == registry._name_field
2521
- add_labels(
2522
- artifact,
2523
- records=labels,
2524
- feature=feature,
2525
- feature_ref_is_name=feature_ref_is_name,
2526
- label_ref_is_name=label_ref_is_name,
2527
- from_curator=True,
2528
- )
2529
-
2530
- return artifact.save()
2531
-
2532
-
2533
- class CellxGeneAnnDataCatManager(AnnDataCatManager):
2534
- """Categorical manager for `AnnData` respecting the CELLxGENE schema.
2535
-
2536
- This will be superseded by a schema-based curation flow.
2537
- """
2538
-
2539
- cxg_categoricals_defaults = {
2540
- "cell_type": "unknown",
2541
- "development_stage": "unknown",
2542
- "disease": "normal",
2543
- "donor_id": "unknown",
2544
- "self_reported_ethnicity": "unknown",
2545
- "sex": "unknown",
2546
- "suspension_type": "cell",
2547
- "tissue_type": "tissue",
2548
- }
2549
-
2550
- def __init__(
2551
- self,
2552
- adata: ad.AnnData,
2553
- categoricals: dict[str, FieldAttr] | None = None,
2554
- *,
2555
- schema_version: Literal["4.0.0", "5.0.0", "5.1.0", "5.2.0"] = "5.2.0",
2556
- defaults: dict[str, str] | None = None,
2557
- extra_sources: dict[str, Record] | None = None,
2558
- ) -> None:
2559
- """CELLxGENE schema curator.
2560
-
2561
- Args:
2562
- adata: Path to or AnnData object to curate against the CELLxGENE schema.
2563
- categoricals: A dictionary mapping ``.obs.columns`` to a registry field.
2564
- The CELLxGENE Curator maps against the required CELLxGENE fields by default.
2565
- schema_version: The CELLxGENE schema version to curate against.
2566
- defaults: Default values that are set if columns or column values are missing.
2567
- extra_sources: A dictionary mapping ``.obs.columns`` to Source records.
2568
- These extra sources are joined with the CELLxGENE fixed sources.
2569
- Use this parameter when subclassing.
2570
- """
2571
- import bionty as bt
2572
-
2573
- from ._cellxgene_schemas import (
2574
- _add_defaults_to_obs,
2575
- _create_sources,
2576
- _init_categoricals_additional_values,
2577
- _restrict_obs_fields,
2578
- )
2579
-
2580
- # Add defaults first to ensure that we fetch valid sources
2581
- if defaults:
2582
- _add_defaults_to_obs(adata.obs, defaults)
2583
-
2584
- # Filter categoricals based on what's present in adata
2585
- if categoricals is None:
2586
- categoricals = self._get_cxg_categoricals()
2587
- categoricals = _restrict_obs_fields(adata.obs, categoricals)
2588
-
2589
- # Configure sources
2590
- organism: Literal["human", "mouse"] = "human"
2591
- sources = _create_sources(categoricals, schema_version, organism)
2592
- self.schema_version = schema_version
2593
- self.schema_reference = f"https://github.com/chanzuckerberg/single-cell-curation/blob/main/schema/{schema_version}/schema.md"
2594
- # These sources are not a part of the cellxgene schema but rather passed through.
2595
- # This is useful when other Curators extend the CELLxGENE curator
2596
- if extra_sources:
2597
- sources = sources | extra_sources
2598
-
2599
- _init_categoricals_additional_values()
2600
-
2601
- super().__init__(
2602
- data=adata,
2603
- var_index=bt.Gene.ensembl_gene_id,
2604
- categoricals=categoricals,
2605
- sources=sources,
2606
- )
2607
-
2608
- @classmethod
2609
- def _get_cxg_categoricals(cls) -> dict[str, FieldAttr]:
2610
- """Returns the CELLxGENE schema mapped fields."""
2611
- from ._cellxgene_schemas import _get_cxg_categoricals
2612
-
2613
- return _get_cxg_categoricals()
2614
-
2615
- def validate(self) -> bool:
2616
- """Validates the AnnData object against most cellxgene requirements."""
2617
- from ._cellxgene_schemas import RESERVED_NAMES
2618
-
2619
- # Verify that all required obs columns are present
2620
- required_columns = list(self.cxg_categoricals_defaults.keys()) + ["donor_id"]
2621
- missing_obs_fields = [
2622
- name
2623
- for name in required_columns
2624
- if name not in self._adata.obs.columns
2625
- and f"{name}_ontology_term_id" not in self._adata.obs.columns
2626
- ]
2627
- if len(missing_obs_fields) > 0:
2628
- logger.error(
2629
- f"missing required obs columns {_format_values(missing_obs_fields)}\n"
2630
- " → consider initializing a Curate object with `defaults=cxg.CellxGeneAnnDataCatManager.cxg_categoricals_defaults` to automatically add these columns with default values"
2631
- )
2632
- return False
2633
-
2634
- # Verify that no cellxgene reserved names are present
2635
- matched_columns = [
2636
- column for column in self._adata.obs.columns if column in RESERVED_NAMES
2637
- ]
2638
- if len(matched_columns) > 0:
2639
- raise ValueError(
2640
- f"AnnData object must not contain obs columns {matched_columns} which are"
2641
- " reserved from previous schema versions."
2642
- )
2643
-
2644
- return super().validate()
2645
-
2646
- def to_cellxgene_anndata(
2647
- self, is_primary_data: bool, title: str | None = None
2648
- ) -> ad.AnnData:
2649
- """Converts the AnnData object to the cellxgene-schema input format.
2650
-
2651
- cellxgene expects the obs fields to be {entity}_ontology_id fields and has many further requirements which are
2652
- documented here: https://github.com/chanzuckerberg/single-cell-curation/tree/main/schema.
2653
- This function checks for most but not all requirements of the CELLxGENE schema.
2654
- If you want to ensure that it fully adheres to the CELLxGENE schema, run `cellxgene-schema` on the AnnData object.
2655
-
2656
- Args:
2657
- is_primary_data: Whether the measured data is primary data or not.
2658
- title: Title of the AnnData object. Commonly the name of the publication.
2659
-
2660
- Returns:
2661
- An AnnData object which adheres to the cellxgene-schema.
2662
- """
2663
-
2664
- def _convert_name_to_ontology_id(values: pd.Series, field: FieldAttr):
2665
- """Converts a column that stores a name into a column that stores the ontology id.
2666
-
2667
- cellxgene expects the obs columns to be {entity}_ontology_id columns and disallows {entity} columns.
2668
- """
2669
- field_name = field.field.name
2670
- assert field_name == "name" # noqa: S101
2671
- cols = ["name", "ontology_id"]
2672
- registry = field.field.model
2673
-
2674
- if hasattr(registry, "ontology_id"):
2675
- validated_records = registry.filter(**{f"{field_name}__in": values})
2676
- mapper = (
2677
- pd.DataFrame(validated_records.values_list(*cols))
2678
- .set_index(0)
2679
- .to_dict()[1]
2680
- )
2681
- return values.map(mapper)
2682
-
2683
- # Create a copy since we modify the AnnData object extensively
2684
- adata_cxg = self._adata.copy()
2685
-
2686
- # cellxgene requires an embedding
2687
- embedding_pattern = r"^[a-zA-Z][a-zA-Z0-9_.-]*$"
2688
- exclude_key = "spatial"
2689
- matching_keys = [
2690
- key
2691
- for key in adata_cxg.obsm.keys()
2692
- if re.match(embedding_pattern, key) and key != exclude_key
2693
- ]
2694
- if len(matching_keys) == 0:
2695
- raise ValueError(
2696
- "Unable to find an embedding key. Please calculate an embedding."
2697
- )
2698
-
2699
- # convert name column to ontology_term_id column
2700
- for column in adata_cxg.obs.columns:
2701
- if column in self.categoricals and not column.endswith("_ontology_term_id"):
2702
- mapped_column = _convert_name_to_ontology_id(
2703
- adata_cxg.obs[column], field=self.categoricals.get(column)
2704
- )
2705
- if mapped_column is not None:
2706
- adata_cxg.obs[f"{column}_ontology_term_id"] = mapped_column
2707
-
2708
- # drop the name columns for ontologies. cellxgene does not allow them.
2709
- drop_columns = [
2710
- i
2711
- for i in adata_cxg.obs.columns
2712
- if f"{i}_ontology_term_id" in adata_cxg.obs.columns
2713
- ]
2714
- adata_cxg.obs.drop(columns=drop_columns, inplace=True)
2715
-
2716
- # Add cellxgene metadata to AnnData object
2717
- if "is_primary_data" not in adata_cxg.obs.columns:
2718
- adata_cxg.obs["is_primary_data"] = is_primary_data
2719
- if "feature_is_filtered" not in adata_cxg.var.columns:
2720
- logger.warn(
2721
- "column 'feature_is_filtered' not present in var. Setting to default"
2722
- " value of False."
2723
- )
2724
- adata_cxg.var["feature_is_filtered"] = False
2725
- if title is None:
2726
- raise ValueError("please pass a title!")
2727
- else:
2728
- adata_cxg.uns["title"] = title
2729
- adata_cxg.uns["cxg_lamin_schema_reference"] = self.schema_reference
2730
- adata_cxg.uns["cxg_lamin_schema_version"] = self.schema_version
2731
-
2732
- return adata_cxg
2733
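
A sketch of the intended round trip with this manager; the title string is an assumption, and final compliance should still be checked with the `cellxgene-schema` CLI as the docstring notes:

    curator = CellxGeneAnnDataCatManager(adata, schema_version="5.2.0")
    if curator.validate():
        adata_cxg = curator.to_cellxgene_anndata(is_primary_data=True, title="My study")
        adata_cxg.write_h5ad("export_for_cellxgene.h5ad")
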
-
2734
-
2735
- class ValueUnit:
2736
- """Base class for handling value-unit combinations."""
2737
-
2738
- @staticmethod
2739
- def parse_value_unit(value: str, is_dose: bool = True) -> tuple[str, str] | None:
2740
- """Parse a string containing a value and unit into a tuple."""
2741
- if not isinstance(value, str) or not value.strip():
2742
- return None
2743
-
2744
- value = str(value).strip()
2745
- match = re.match(r"^(\d*\.?\d{0,1})\s*([a-zA-ZμµΜ]+)$", value)
2746
-
2747
- if not match:
2748
- raise ValueError(
2749
- f"Invalid format: {value}. Expected format: number with max 1 decimal place + unit"
2750
- )
2751
-
2752
- number, unit = match.groups()
2753
- formatted_number = f"{float(number):.1f}"
2754
-
2755
- if is_dose:
2756
- standardized_unit = DoseHandler.standardize_unit(unit)
2757
- if not DoseHandler.validate_unit(standardized_unit):
2758
- raise ValueError(
2759
- f"Invalid dose unit: {unit}. Must be convertible to one of: nM, μM, mM, M"
2760
- )
2761
- else:
2762
- standardized_unit = TimeHandler.standardize_unit(unit)
2763
- if not TimeHandler.validate_unit(standardized_unit):
2764
- raise ValueError(
2765
- f"Invalid time unit: {unit}. Must be convertible to one of: h, m, s, d, y"
2766
- )
2767
-
2768
- return formatted_number, standardized_unit
2769
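
Concretely, with the handlers defined below, `parse_value_unit` pins the number to one decimal place and canonicalizes the unit (values illustrative):

    ValueUnit.parse_value_unit("10 um")                 # -> ("10.0", "μM")
    ValueUnit.parse_value_unit("24 hr", is_dose=False)  # -> ("24.0", "h")
    ValueUnit.parse_value_unit("10.55 nM")              # ValueError: max 1 decimal place
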
-
2770
-
2771
- class DoseHandler:
2772
- """Handler for dose-related operations."""
2773
-
2774
- VALID_UNITS = {"nM", "μM", "µM", "mM", "M"}
2775
- UNIT_MAP = {
2776
- "nm": "nM",
2777
- "NM": "nM",
2778
- "um": "μM",
2779
- "uM": "μM",
- "UM": "μM",
2780
- "μm": "μM",
2781
- "μM": "μM",
2782
- "µm": "μM",
2783
- "µM": "μM",
2784
- "mm": "mM",
2785
- "MM": "mM",
2786
- "m": "M",
2787
- "M": "M",
2788
- }
2789
-
2790
- @classmethod
2791
- def validate_unit(cls, unit: str) -> bool:
2792
- """Validate if the dose unit is acceptable."""
2793
- return unit in cls.VALID_UNITS
2794
-
2795
- @classmethod
2796
- def standardize_unit(cls, unit: str) -> str:
2797
- """Standardize dose unit to standard formats."""
2798
- return cls.UNIT_MAP.get(unit, unit)
2799
-
2800
- @classmethod
2801
- def validate_values(cls, values: pd.Series) -> list[str]:
2802
- """Validate pert_dose values with strict case checking."""
2803
- errors = []
2804
-
2805
- for idx, value in values.items():
2806
- if pd.isna(value):
2807
- continue
2808
-
2809
- if isinstance(value, (int, float)):
2810
- errors.append(
2811
- f"Row {idx} - Missing unit for dose: {value}. Must include a unit (nM, μM, mM, M)"
2812
- )
2813
- continue
2814
-
2815
- try:
2816
- ValueUnit.parse_value_unit(value, is_dose=True)
2817
- except ValueError as e:
2818
- errors.append(f"Row {idx} - {str(e)}")
2819
-
2820
- return errors
2821
-
2822
-
2823
- class TimeHandler:
2824
- """Handler for time-related operations."""
2825
-
2826
- VALID_UNITS = {"h", "m", "s", "d", "y"}
2827
-
2828
- @classmethod
2829
- def validate_unit(cls, unit: str) -> bool:
2830
- """Validate if the time unit is acceptable."""
2831
- return unit == unit.lower() and unit in cls.VALID_UNITS
2832
-
2833
- @classmethod
2834
- def standardize_unit(cls, unit: str) -> str:
2835
- """Standardize time unit to standard formats."""
2836
- if unit.startswith("hr"):
2837
- return "h"
2838
- elif unit.startswith("min"):
2839
- return "m"
2840
- elif unit.startswith("sec"):
2841
- return "s"
2842
- return unit[0].lower()
2843
-
2844
- @classmethod
2845
- def validate_values(cls, values: pd.Series) -> list[str]:
2846
- """Validate pert_time values."""
2847
- errors = []
2848
-
2849
- for idx, value in values.items():
2850
- if pd.isna(value):
2851
- continue
2852
-
2853
- if isinstance(value, (int, float)):
2854
- errors.append(
2855
- f"Row {idx} - Missing unit for time: {value}. Must include a unit (h, m, s, d, y)"
2856
- )
2857
- continue
2858
-
2859
- try:
2860
- ValueUnit.parse_value_unit(value, is_dose=False)
2861
- except ValueError as e:
2862
- errors.append(f"Row {idx} - {str(e)}")
2863
-
2864
- return errors
2865
-
2866
-
2867
- class PertAnnDataCatManager(CellxGeneAnnDataCatManager):
2868
- """Categorical manager for `AnnData` to manage perturbations."""
2869
-
2870
- PERT_COLUMNS = {"compound", "genetic", "biologic", "physical"}
2871
-
2872
- def __init__(
2873
- self,
2874
- adata: ad.AnnData,
2875
- organism: Literal["human", "mouse"] = "human",
2876
- pert_dose: bool = True,
2877
- pert_time: bool = True,
2878
- *,
2879
- cxg_schema_version: Literal["5.0.0", "5.1.0", "5.2.0"] = "5.2.0",
2880
- ):
2881
- """Initialize the curator with configuration and validation settings."""
2882
- self._pert_time = pert_time
2883
- self._pert_dose = pert_dose
2884
-
2885
- self._validate_initial_data(adata)
2886
- categoricals, categoricals_defaults = self._configure_categoricals(adata)
2887
-
2888
- super().__init__(
2889
- adata=adata,
2890
- categoricals=categoricals,
2891
- defaults=categoricals_defaults,
2892
- extra_sources=self._configure_sources(adata),
2893
- schema_version=cxg_schema_version,
2894
- )
2895
-
2896
- def _configure_categoricals(self, adata: ad.AnnData):
2897
- """Set up default configuration values."""
2898
- import bionty as bt
2899
- import wetlab as wl
2900
-
2901
- categoricals = CellxGeneAnnDataCatManager._get_cxg_categoricals() | {
2902
- k: v
2903
- for k, v in {
2904
- "cell_line": bt.CellLine.name,
2905
- "pert_target": wl.PerturbationTarget.name,
2906
- "pert_genetic": wl.GeneticPerturbation.name,
2907
- "pert_compound": wl.Compound.name,
2908
- "pert_biologic": wl.Biologic.name,
2909
- "pert_physical": wl.EnvironmentalPerturbation.name,
2910
- }.items()
2911
- if k in adata.obs.columns
2912
- }
2913
- # if "donor_id" in categoricals:
2914
- # categoricals["donor_id"] = Donor.name
2915
-
2916
- categoricals_defaults = CellxGeneAnnDataCatManager.cxg_categoricals_defaults | {
2917
- "cell_line": "unknown",
2918
- "pert_target": "unknown",
2919
- }
2920
-
2921
- return categoricals, categoricals_defaults
2922
-
2923
- def _configure_sources(self, adata: ad.AnnData):
2924
- """Set up data sources."""
2925
- import bionty as bt
2926
- import wetlab as wl
2927
-
2928
- sources = {}
2929
- # # do not yet specify cell_line source
2930
- # if "cell_line" in adata.obs.columns:
2931
- # sources["cell_line"] = bt.Source.filter(
2932
- # entity="bionty.CellLine", name="depmap"
2933
- # ).first()
2934
- if "pert_compound" in adata.obs.columns:
2935
- with logger.mute():
2936
- chebi_source = bt.Source.filter(
2937
- entity="wetlab.Compound", name="chebi"
2938
- ).first()
2939
- if not chebi_source:
2940
- wl.Compound.add_source(
2941
- bt.Source.filter(entity="Drug", name="chebi").first()
2942
- )
2943
-
2944
- sources["pert_compound"] = bt.Source.filter(
2945
- entity="wetlab.Compound", name="chebi"
2946
- ).first()
2947
- return sources
2948
-
2949
- def _validate_initial_data(self, adata: ad.AnnData):
2950
- """Validate the initial data structure."""
2951
- self._validate_required_columns(adata)
2952
- self._validate_perturbation_types(adata)
2953
-
2954
- def _validate_required_columns(self, adata: ad.AnnData):
2955
- """Validate required columns are present."""
2956
- if "pert_target" not in adata.obs.columns:
2957
- if (
2958
- "pert_name" not in adata.obs.columns
2959
- or "pert_type" not in adata.obs.columns
2960
- ):
2961
- raise ValidationError(
2962
- "either 'pert_target' or both 'pert_name' and 'pert_type' must be present"
2963
- )
2964
- else:
2965
- if "pert_name" not in adata.obs.columns:
2966
- logger.warning(
2967
- "no 'pert_name' column found in adata.obs, will only curate 'pert_target'"
2968
- )
2969
- elif "pert_type" not in adata.obs.columns:
2970
- raise ValidationError("both 'pert_name' and 'pert_type' must be present")
2971
-
2972
- def _validate_perturbation_types(self, adata: ad.AnnData):
2973
- """Validate perturbation types."""
2974
- if "pert_type" in adata.obs.columns:
2975
- data_pert_types = set(adata.obs["pert_type"].unique())
2976
- invalid_pert_types = data_pert_types - self.PERT_COLUMNS
2977
- if invalid_pert_types:
2978
- raise ValidationError(
2979
- f"invalid pert_type found: {invalid_pert_types}!\n"
2980
- f" → allowed values: {self.PERT_COLUMNS}"
2981
- )
2982
- self._process_perturbation_types(adata, data_pert_types)
2983
-
2984
- def _process_perturbation_types(self, adata: ad.AnnData, pert_types: set):
2985
- """Process and map perturbation types."""
2986
- for pert_type in pert_types:
2987
- col_name = "pert_" + pert_type
2988
- adata.obs[col_name] = adata.obs["pert_name"].where(
2989
- adata.obs["pert_type"] == pert_type, None
2990
- )
2991
- if adata.obs[col_name].dtype.name == "category":
2992
- adata.obs[col_name] = adata.obs[col_name].cat.remove_unused_categories()  # not in-place
2993
- logger.important(f"mapped 'pert_name' to '{col_name}'")
2994
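
i.e., the long `pert_name`/`pert_type` pair is pivoted into one sparse column per perturbation type (rows illustrative):

    #  pert_name     pert_type      ->  pert_compound   pert_genetic
    #  "dabrafenib"  "compound"         "dabrafenib"    None
    #  "BRAF-KO"     "genetic"          None            "BRAF-KO"
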
-
2995
- def validate(self) -> bool: # type: ignore
2996
- """Validate the AnnData object."""
2997
- validated = super().validate()
2998
-
2999
- if self._pert_dose:
3000
- validated &= self._validate_dose_column()
3001
- if self._pert_time:
3002
- validated &= self._validate_time_column()
3003
-
3004
- self._is_validated = validated
3005
-
3006
- # sort columns
3007
- first_columns = [
3008
- "pert_target",
3009
- "pert_genetic",
3010
- "pert_compound",
3011
- "pert_biologic",
3012
- "pert_physical",
3013
- "pert_dose",
3014
- "pert_time",
3015
- "organism",
3016
- "cell_line",
3017
- "cell_type",
3018
- "disease",
3019
- "tissue_type",
3020
- "tissue",
3021
- "assay",
3022
- "suspension_type",
3023
- "donor_id",
3024
- "sex",
3025
- "self_reported_ethnicity",
3026
- "development_stage",
3027
- "pert_name",
3028
- "pert_type",
3029
- ]
3030
- sorted_columns = [
3031
- col for col in first_columns if col in self._adata.obs.columns
3032
- ] + [col for col in self._adata.obs.columns if col not in first_columns]
3033
- # must assign to self._obs_df to ensure .standardize works correctly
3034
- self._obs_df = self._adata.obs[sorted_columns]
3035
- self._adata.obs = self._obs_df
3036
- return validated
3037
-
3038
- def standardize(self, key: str) -> None:
3039
- """Standardize the AnnData object."""
3040
- super().standardize(key)
3041
- self._adata.obs = self._obs_df
3042
-
3043
- def _validate_dose_column(self) -> bool:
3044
- """Validate the dose column."""
3045
- if not Feature.filter(name="pert_dose").exists():
3046
- Feature(name="pert_dose", dtype="str").save() # type: ignore
3047
-
3048
- dose_errors = DoseHandler.validate_values(self._adata.obs["pert_dose"])
3049
- if dose_errors:
3050
- self._log_validation_errors("pert_dose", dose_errors)
3051
- return False
3052
- return True
3053
-
3054
- def _validate_time_column(self) -> bool:
3055
- """Validate the time column."""
3056
- if not Feature.filter(name="pert_time").exists():
3057
- Feature(name="pert_time", dtype="str").save() # type: ignore
3058
-
3059
- time_errors = TimeHandler.validate_values(self._adata.obs["pert_time"])
3060
- if time_errors:
3061
- self._log_validation_errors("pert_time", time_errors)
3062
- return False
3063
- return True
3064
-
3065
- def _log_validation_errors(self, column: str, errors: list):
3066
- """Log validation errors with formatting."""
3067
- errors_print = "\n ".join(errors)
3068
- logger.warning(
3069
- f"invalid {column} values found!\n {errors_print}\n"
3070
- f" → run {colors.cyan('standardize_dose_time()')}"
3071
- )
3072
-
3073
- def standardize_dose_time(self) -> pd.DataFrame:
3074
- """Standardize dose and time values."""
3075
- standardized_df = self._adata.obs.copy()
3076
-
3077
- if "pert_dose" in self._adata.obs.columns:
3078
- standardized_df = self._standardize_column(
3079
- standardized_df, "pert_dose", is_dose=True
3080
- )
3081
-
3082
- if "pert_time" in self._adata.obs.columns:
3083
- standardized_df = self._standardize_column(
3084
- standardized_df, "pert_time", is_dose=False
3085
- )
3086
-
3087
- self._adata.obs = standardized_df
3088
- return standardized_df
3089
-
3090
- def _standardize_column(
3091
- self, df: pd.DataFrame, column: str, is_dose: bool
3092
- ) -> pd.DataFrame:
3093
- """Standardize values in a specific column."""
3094
- for idx, value in self._adata.obs[column].items():
3095
- if pd.isna(value) or (
3096
- isinstance(value, str) and (not value.strip() or value.lower() == "nan")
3097
- ):
3098
- df.at[idx, column] = None
3099
- continue
3100
-
3101
- try:
3102
- num, unit = ValueUnit.parse_value_unit(value, is_dose=is_dose)
3103
- df.at[idx, column] = f"{num}{unit}"
3104
- except ValueError:
3105
- continue
3106
-
3107
- return df
3108
-
3109
-
3110
- def get_current_filter_kwargs(registry: type[Record], kwargs: dict) -> dict:
3111
- """Make sure the source and organism are saved in the same database as the registry."""
3112
- db = registry.filter().db
3113
- source = kwargs.get("source")
3114
- organism = kwargs.get("organism")
3115
- filter_kwargs = kwargs.copy()
3116
-
3117
- if isinstance(organism, Record) and organism._state.db != "default":
3118
- if db is None or db == "default":
3119
- organism_default = copy.copy(organism)
3120
- # save the organism record in the default database
3121
- organism_default.save()
3122
- filter_kwargs["organism"] = organism_default
3123
- if isinstance(source, Record) and source._state.db != "default":
3124
- if db is None or db == "default":
3125
- source_default = copy.copy(source)
3126
- # save the source record in the default database
3127
- source_default.save()
3128
- filter_kwargs["source"] = source_default
3129
-
3130
- return filter_kwargs
3131
-
3132
-
3133
- def get_organism_kwargs(
3134
- field: FieldAttr, organism: str | None = None, values: Any = None
3135
- ) -> dict[str, str]:
3136
- """Check if a registry needs an organism and return the organism name."""
3137
- registry = field.field.model
3138
- if registry.__base__.__name__ == "BioRecord":
3139
- import bionty as bt
3140
- from bionty._organism import is_organism_required
3141
-
3142
- from ..models._from_values import get_organism_record_from_field
3143
-
3144
- if is_organism_required(registry):
3145
- if organism is not None or bt.settings.organism is not None:
3146
- return {"organism": organism or bt.settings.organism.name}
3147
- else:
3148
- organism_record = get_organism_record_from_field(
3149
- field, organism=organism, values=values
3150
- )
3151
- if organism_record is not None:
3152
- return {"organism": organism_record.name}
3153
- return {}
-
-
- def annotate_artifact(
-     artifact: Artifact,
-     *,
-     schema: Schema | None = None,
-     cat_columns: dict[str, CatColumn] | None = None,
-     index_field: FieldAttr | dict[str, FieldAttr] | None = None,
-     **kwargs,
- ) -> Artifact:
-     from ..models.artifact import add_labels
-
-     if cat_columns is None:
-         cat_columns = {}
-
-     # annotate with labels
-     for key, cat_column in cat_columns.items():
-         if (
-             cat_column._field.field.model == Feature
-             or key == "columns"
-             or key == "var_index"
-         ):
-             continue
-         add_labels(
-             artifact,
-             records=cat_column.labels,
-             feature=cat_column.feature,
-             feature_ref_is_name=None,  # no longer needed
-             label_ref_is_name=cat_column.label_ref_is_name,
-             from_curator=True,
-         )
-
-     # annotate with inferred feature sets
-     match artifact.otype:
-         case "DataFrame":
-             artifact.features._add_set_from_df(field=index_field)  # type: ignore
-         case "AnnData":
-             if schema is not None and "uns" in schema.slots:
-                 uns_field = parse_cat_dtype(schema.slots["uns"].itype, is_itype=True)[
-                     "field"
-                 ]
-             else:
-                 uns_field = None
-             artifact.features._add_set_from_anndata(  # type: ignore
-                 var_field=index_field, uns_field=uns_field
-             )
-         case "MuData":
-             artifact.features._add_set_from_mudata(var_fields=index_field)  # type: ignore
-         case "SpatialData":
-             artifact.features._add_set_from_spatialdata(  # type: ignore
-                 sample_metadata_key=kwargs.get("sample_metadata_key", "sample"),
-                 var_fields=index_field,
-             )
-         case _:
-             raise NotImplementedError  # pragma: no cover
-
-     slug = ln_setup.settings.instance.slug
-     if ln_setup.settings.instance.is_remote:  # pragma: no cover
-         logger.important(f"go to https://lamin.ai/{slug}/artifact/{artifact.uid}")
-     return artifact
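
The `match` on `artifact.otype` (Python ≥ 3.10 structural pattern matching) dispatches on the object type stored with the artifact, with `case _` as the unsupported-type fall-through. A self-contained toy version of the same dispatch shape; the function name is illustrative:

    def infer_feature_sets(otype: str) -> str:
        match otype:
            case "DataFrame":
                return "parse columns"
            case "AnnData":
                return "parse var and obs"
            case _:
                raise NotImplementedError(otype)

    print(infer_feature_sets("AnnData"))  # parse var and obs
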
-
-
- # TODO: need this function to support multi-value columns
- def _flatten_unique(series: pd.Series[list[Any] | Any]) -> list[Any]:
-     """Flatten a Pandas series containing lists or single items into a unique list of elements."""
-     result = set()
-
-     for item in series:
-         if isinstance(item, list):
-             result.update(item)
-         else:
-             result.add(item)
-
-     return list(result)
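
Runnable example of the helper above; the set makes element order arbitrary, hence the sort:

    import pandas as pd

    s = pd.Series([["CD4", "CD8A"], "CD3E", ["CD4"]])
    print(sorted(_flatten_unique(s)))  # ['CD3E', 'CD4', 'CD8A']
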
-
-
- def _save_organism(name: str):
-     """Save an organism record."""
-     import bionty as bt
-
-     organism = bt.Organism.filter(name=name).one_or_none()
-     if organism is None:
-         organism = bt.Organism.from_source(name=name)
-         if organism is None:
-             raise ValidationError(
-                 f'Organism "{name}" not found from public reference\n'
-                 f'    → please save it from a different source: bt.Organism.from_source(name="{name}", source).save()\n'
-                 f'    → or manually save it without source: bt.Organism(name="{name}").save()'
-             )
-         organism.save()
-     return organism
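
The lookup-then-fallback flow above, as a hedged standalone sketch (requires bionty and a configured instance; "mouse" is illustrative):

    import bionty as bt

    organism = bt.Organism.filter(name="mouse").one_or_none()
    if organism is None:
        # pull the record from the default public source, then persist it
        organism = bt.Organism.from_source(name="mouse")
        organism.save()
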
-
-
- def _ref_is_name(field: FieldAttr | None) -> bool | None:
-     """Check if the reference field is a name field."""
-     from ..models.can_curate import get_name_field
-
-     if field is not None:
-         name_field = get_name_field(field.field.model)
-         return field.field.name == name_field
-     return None
-
-
- # backward compat constructors ------------------
-
-
- @classmethod  # type: ignore
- def from_df(
-     cls,
-     df: pd.DataFrame,
-     categoricals: dict[str, FieldAttr] | None = None,
-     columns: FieldAttr = Feature.name,
-     organism: str | None = None,
- ) -> DataFrameCatManager:
-     if organism is not None:
-         logger.warning("organism is ignored, define it on the dtype level")
-     return DataFrameCatManager(
-         df=df,
-         categoricals=categoricals,
-         columns=columns,
-     )
-
-
- @classmethod  # type: ignore
- def from_anndata(
-     cls,
-     data: ad.AnnData | UPathStr,
-     var_index: FieldAttr,
-     categoricals: dict[str, FieldAttr] | None = None,
-     obs_columns: FieldAttr = Feature.name,
-     organism: str | None = None,
-     sources: dict[str, Record] | None = None,
- ) -> AnnDataCatManager:
-     if organism is not None:
-         logger.warning("organism is ignored, define it on the dtype level")
-     return AnnDataCatManager(
-         data=data,
-         var_index=var_index,
-         categoricals=categoricals,
-         obs_columns=obs_columns,
-         sources=sources,
-     )
-
-
- @classmethod  # type: ignore
- def from_mudata(
-     cls,
-     mdata: MuData | UPathStr,
-     var_index: dict[str, dict[str, FieldAttr]],
-     categoricals: dict[str, FieldAttr] | None = None,
-     organism: str | None = None,
- ) -> MuDataCatManager:
-     if not is_package_installed("mudata"):
-         raise ImportError("Please install mudata: pip install mudata")
-     if organism is not None:
-         logger.warning("organism is ignored, define it on the dtype level")
-     return MuDataCatManager(
-         mdata=mdata,
-         var_index=var_index,
-         categoricals=categoricals,
-     )
-
-
- @classmethod  # type: ignore
- def from_tiledbsoma(
-     cls,
-     experiment_uri: UPathStr,
-     var_index: dict[str, tuple[str, FieldAttr]],
-     categoricals: dict[str, FieldAttr] | None = None,
-     obs_columns: FieldAttr = Feature.name,
-     organism: str | None = None,
-     sources: dict[str, Record] | None = None,
- ) -> TiledbsomaCatManager:
-     if organism is not None:
-         logger.warning("organism is ignored, define it on the dtype level")
-     return TiledbsomaCatManager(
-         experiment_uri=experiment_uri,
-         var_index=var_index,
-         categoricals=categoricals,
-         obs_columns=obs_columns,
-         sources=sources,
-     )
-
-
- @classmethod  # type: ignore
- def from_spatialdata(
-     cls,
-     sdata: SpatialData | UPathStr,
-     var_index: dict[str, FieldAttr],
-     categoricals: dict[str, dict[str, FieldAttr]] | None = None,
-     organism: str | None = None,
-     sources: dict[str, dict[str, Record]] | None = None,
-     *,
-     sample_metadata_key: str = "sample",
- ):
-     if not is_package_installed("spatialdata"):
-         raise ImportError("Please install spatialdata: pip install spatialdata")
-     if organism is not None:
-         logger.warning("organism is ignored, define it on the dtype level")
-     return SpatialDataCatManager(
-         sdata=sdata,
-         var_index=var_index,
-         categoricals=categoricals,
-         sources=sources,
-         sample_metadata_key=sample_metadata_key,
-     )
-
-
- CatManager.from_df = from_df  # type: ignore
- CatManager.from_anndata = from_anndata  # type: ignore
- CatManager.from_mudata = from_mudata  # type: ignore
- CatManager.from_spatialdata = from_spatialdata  # type: ignore
- CatManager.from_tiledbsoma = from_tiledbsoma  # type: ignore
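
These assignments bolt the deprecated constructors onto `CatManager` after the class definition. The same pattern in isolation, as plain runnable Python with illustrative names:

    class Manager:
        """Stand-in for CatManager."""

    @classmethod
    def from_list(cls, values):  # illustrative backward-compat constructor
        obj = cls()
        obj.values = list(values)
        return obj

    Manager.from_list = from_list  # attach after class creation; the descriptor binds cls
    m = Manager.from_list([1, 2, 3])
    assert m.values == [1, 2, 3]
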
+ from .core import AnnDataCurator, DataFrameCurator, MuDataCurator, SpatialDataCurator
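
With the module trimmed to re-exports, curation in 1.5 goes through the schema-based curators now housed in `lamindb.curators.core`. A hedged sketch of the replacement flow; the feature, schema, and key names are illustrative:

    import lamindb as ln
    import pandas as pd

    df = pd.DataFrame({"perturbation": ["DMSO", "IFNG"]})
    schema = ln.Schema(
        features=[ln.Feature(name="perturbation", dtype="cat").save()],
    ).save()
    curator = ln.curators.DataFrameCurator(df, schema)
    curator.validate()  # flags categorical values that are not yet registered
    artifact = curator.save_artifact(key="examples/dataset.parquet")
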