lamindb 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. lamindb/__init__.py +52 -36
  2. lamindb/_finish.py +17 -10
  3. lamindb/_tracked.py +1 -1
  4. lamindb/base/__init__.py +3 -1
  5. lamindb/base/fields.py +40 -22
  6. lamindb/base/ids.py +1 -94
  7. lamindb/base/types.py +2 -0
  8. lamindb/base/uids.py +117 -0
  9. lamindb/core/_context.py +216 -133
  10. lamindb/core/_settings.py +38 -25
  11. lamindb/core/datasets/__init__.py +11 -4
  12. lamindb/core/datasets/_core.py +5 -5
  13. lamindb/core/datasets/_small.py +0 -93
  14. lamindb/core/datasets/mini_immuno.py +172 -0
  15. lamindb/core/loaders.py +1 -1
  16. lamindb/core/storage/_backed_access.py +100 -6
  17. lamindb/core/storage/_polars_lazy_df.py +51 -0
  18. lamindb/core/storage/_pyarrow_dataset.py +15 -30
  19. lamindb/core/storage/objects.py +6 -0
  20. lamindb/core/subsettings/__init__.py +2 -0
  21. lamindb/core/subsettings/_annotation_settings.py +11 -0
  22. lamindb/curators/__init__.py +7 -3559
  23. lamindb/curators/_legacy.py +2056 -0
  24. lamindb/curators/core.py +1546 -0
  25. lamindb/errors.py +11 -0
  26. lamindb/examples/__init__.py +27 -0
  27. lamindb/examples/schemas/__init__.py +12 -0
  28. lamindb/examples/schemas/_anndata.py +25 -0
  29. lamindb/examples/schemas/_simple.py +19 -0
  30. lamindb/integrations/_vitessce.py +8 -5
  31. lamindb/migrations/0091_alter_featurevalue_options_alter_space_options_and_more.py +24 -0
  32. lamindb/migrations/0092_alter_artifactfeaturevalue_artifact_and_more.py +75 -0
  33. lamindb/models/__init__.py +12 -2
  34. lamindb/models/_describe.py +21 -4
  35. lamindb/models/_feature_manager.py +384 -301
  36. lamindb/models/_from_values.py +1 -1
  37. lamindb/models/_is_versioned.py +5 -15
  38. lamindb/models/_label_manager.py +8 -2
  39. lamindb/models/artifact.py +354 -177
  40. lamindb/models/artifact_set.py +122 -0
  41. lamindb/models/can_curate.py +4 -1
  42. lamindb/models/collection.py +79 -56
  43. lamindb/models/core.py +1 -1
  44. lamindb/models/feature.py +78 -47
  45. lamindb/models/has_parents.py +24 -9
  46. lamindb/models/project.py +3 -3
  47. lamindb/models/query_manager.py +221 -22
  48. lamindb/models/query_set.py +251 -206
  49. lamindb/models/record.py +211 -344
  50. lamindb/models/run.py +59 -5
  51. lamindb/models/save.py +9 -5
  52. lamindb/models/schema.py +673 -196
  53. lamindb/models/transform.py +5 -14
  54. lamindb/models/ulabel.py +8 -5
  55. {lamindb-1.3.2.dist-info → lamindb-1.5.0.dist-info}/METADATA +8 -7
  56. lamindb-1.5.0.dist-info/RECORD +108 -0
  57. lamindb-1.3.2.dist-info/RECORD +0 -95
  58. {lamindb-1.3.2.dist-info → lamindb-1.5.0.dist-info}/LICENSE +0 -0
  59. {lamindb-1.3.2.dist-info → lamindb-1.5.0.dist-info}/WHEEL +0 -0
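
Most of the churn in this diff is the refactor of `lamindb.curators`: in 1.3.2 all curator and cat-manager classes lived inline in `lamindb/curators/__init__.py`; in 1.5.0 the implementation moves into the new `lamindb/curators/core.py` (+1546) and `lamindb/curators/_legacy.py` (+2056), leaving `__init__.py` as a thin re-export layer. A minimal sketch of what that means for downstream imports, assuming the documented curator classes are still re-exported from `lamindb.curators` as the updated module docstring suggests (confirm against the installed 1.5.0 package):

    import lamindb as ln

    # the documented curator classes remain addressable as before
    curator = ln.curators.DataFrameCurator(df, schema)  # df, schema as in the 1.3.2 docstring examples

    # the implementation now lives in the new core module listed in the docstring
    from lamindb.curators import core

    # backward-compat names that the new __init__.py explicitly re-exports from _legacy
    from lamindb.curators import CellxGeneAnnDataCatManager, PertAnnDataCatManager

The hunk below shows the `__init__.py` side of this move.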
@@ -5,3572 +5,20 @@
5
5
 
6
6
  DataFrameCurator
7
7
  AnnDataCurator
8
- MuDataCurator
9
8
  SpatialDataCurator
9
+ MuDataCurator
10
10
 
11
- Helper classes.
11
+ Modules.
12
12
 
13
13
  .. autosummary::
14
14
  :toctree: .
15
15
 
16
- Curator
17
- SlotsCurator
18
- CatManager
19
- CatLookup
20
- DataFrameCatManager
21
- AnnDataCatManager
22
- MuDataCatManager
23
- SpatialDataCatManager
24
- TiledbsomaCatManager
16
+ core
25
17
 
26
18
  """
27
19
 
28
- from __future__ import annotations
29
-
30
- import copy
31
- import re
32
- from itertools import chain
33
- from typing import TYPE_CHECKING, Any, Callable, Literal
34
-
35
- import anndata as ad
36
- import lamindb_setup as ln_setup
37
- import pandas as pd
38
- import pandera
39
- import pyarrow as pa
40
- from lamin_utils import colors, logger
41
- from lamindb_setup.core import deprecated
42
- from lamindb_setup.core._docs import doc_args
43
- from lamindb_setup.core.upath import UPath
44
-
45
- if TYPE_CHECKING:
46
- from lamindb_setup.core.types import UPathStr
47
- from mudata import MuData
48
- from spatialdata import SpatialData
49
-
50
- from lamindb.core.types import ScverseDataStructures
51
- from lamindb.models import Record
52
- from lamindb.base.types import FieldAttr # noqa
53
- from lamindb.core._settings import settings
54
- from lamindb.models import (
55
- Artifact,
56
- Feature,
57
- Record,
58
- Run,
59
- Schema,
60
- ULabel,
20
+ from ._legacy import ( # backward compat
21
+ CellxGeneAnnDataCatManager,
22
+ PertAnnDataCatManager,
61
23
  )
62
- from lamindb.models.artifact import (
63
- add_labels,
64
- data_is_anndata,
65
- data_is_mudata,
66
- data_is_spatialdata,
67
- )
68
- from lamindb.models.feature import parse_dtype, parse_cat_dtype
69
- from lamindb.models._from_values import _format_values
70
-
71
- from ..errors import InvalidArgument, ValidationError
72
- from anndata import AnnData
73
-
74
- if TYPE_CHECKING:
75
- from collections.abc import Iterable, MutableMapping
76
- from typing import Any
77
-
78
- from lamindb_setup.core.types import UPathStr
79
-
80
- from lamindb.models.query_set import RecordList
81
-
82
-
83
- def strip_ansi_codes(text):
84
- # This pattern matches ANSI escape sequences
85
- ansi_pattern = re.compile(r"\x1b\[[0-9;]*m")
86
- return ansi_pattern.sub("", text)
87
-
88
-
89
- class CatLookup:
90
- """Lookup categories from the reference instance.
91
-
92
- Args:
93
- categoricals: A dictionary of categorical fields to lookup.
94
- slots: A dictionary of slot fields to lookup.
95
- public: Whether to lookup from the public instance. Defaults to False.
96
-
97
- Example::
98
-
99
- curator = ln.curators.DataFrameCurator(...)
100
- curator.cat.lookup()["cell_type"].alveolar_type_1_fibroblast_cell
101
-
102
- """
103
-
104
- def __init__(
105
- self,
106
- categoricals: dict[str, FieldAttr],
107
- slots: dict[str, FieldAttr] = None,
108
- public: bool = False,
109
- organism: str | None = None,
110
- sources: dict[str, Record] | None = None,
111
- ) -> None:
112
- slots = slots or {}
113
- self._categoricals = {**categoricals, **slots}
114
- self._public = public
115
- self._organism = organism
116
- self._sources = sources
117
-
118
- def __getattr__(self, name):
119
- if name in self._categoricals:
120
- registry = self._categoricals[name].field.model
121
- if self._public and hasattr(registry, "public"):
122
- return registry.public(
123
- organism=self._organism, source=self._sources.get(name)
124
- ).lookup()
125
- else:
126
- return registry.lookup()
127
- raise AttributeError(
128
- f'"{self.__class__.__name__}" object has no attribute "{name}"'
129
- )
130
-
131
- def __getitem__(self, name):
132
- if name in self._categoricals:
133
- registry = self._categoricals[name].field.model
134
- if self._public and hasattr(registry, "public"):
135
- return registry.public(
136
- organism=self._organism, source=self._sources.get(name)
137
- ).lookup()
138
- else:
139
- return registry.lookup()
140
- raise AttributeError(
141
- f'"{self.__class__.__name__}" object has no attribute "{name}"'
142
- )
143
-
144
- def __repr__(self) -> str:
145
- if len(self._categoricals) > 0:
146
- getattr_keys = "\n ".join(
147
- [f".{key}" for key in self._categoricals if key.isidentifier()]
148
- )
149
- getitem_keys = "\n ".join(
150
- [str([key]) for key in self._categoricals if not key.isidentifier()]
151
- )
152
- ref = "public" if self._public else "registries"
153
- return (
154
- f"Lookup objects from the {colors.italic(ref)}:\n "
155
- f"{colors.green(getattr_keys)}\n "
156
- f"{colors.green(getitem_keys)}\n"
157
- 'Example:\n → categories = curator.lookup()["cell_type"]\n'
158
- " → categories.alveolar_type_1_fibroblast_cell\n\n"
159
- "To look up public ontologies, use .lookup(public=True)"
160
- )
161
- else: # pragma: no cover
162
- return colors.warning("No fields are found!")
163
-
164
-
165
- CAT_MANAGER_DOCSTRING = """Manage categoricals by updating registries."""
166
-
167
-
168
- SLOTS_DOCSTRING = """Curator objects by slot.
169
-
170
- .. versionadded:: 1.1.1
171
- """
172
-
173
-
174
- VALIDATE_DOCSTRING = """Validate dataset against Schema.
175
-
176
- Raises:
177
- lamindb.errors.ValidationError: If validation fails.
178
- """
179
-
180
- SAVE_ARTIFACT_DOCSTRING = """Save an annotated artifact.
181
-
182
- Args:
183
- key: A path-like key to reference artifact in default storage, e.g., `"myfolder/myfile.fcs"`. Artifacts with the same key form a version family.
184
- description: A description.
185
- revises: Previous version of the artifact. Is an alternative way to passing `key` to trigger a new version.
186
- run: The run that creates the artifact.
187
-
188
- Returns:
189
- A saved artifact record.
190
- """
191
-
192
-
193
- class Curator:
194
- """Curator base class.
195
-
196
- A `Curator` object makes it easy to validate, standardize & annotate datasets.
197
-
198
- See:
199
- - :class:`~lamindb.curators.DataFrameCurator`
200
- - :class:`~lamindb.curators.AnnDataCurator`
201
- - :class:`~lamindb.curators.MuDataCurator`
202
- - :class:`~lamindb.curators.SpatialDataCurator`
203
-
204
- .. versionadded:: 1.1.0
205
- """
206
-
207
- def __init__(self, dataset: Any, schema: Schema | None = None):
208
- self._artifact: Artifact = None # pass the dataset as an artifact
209
- self._dataset: Any = dataset # pass the dataset as a UPathStr or data object
210
- if isinstance(self._dataset, Artifact):
211
- self._artifact = self._dataset
212
- if self._artifact.otype in {
213
- "DataFrame",
214
- "AnnData",
215
- "MuData",
216
- "SpatialData",
217
- }:
218
- self._dataset = self._dataset.load()
219
- self._schema: Schema | None = schema
220
- self._is_validated: bool = False
221
- self._cat_manager: CatManager = None # is None for CatManager curators
222
-
223
- @doc_args(VALIDATE_DOCSTRING)
224
- def validate(self) -> bool | str:
225
- """{}""" # noqa: D415
226
- pass # pragma: no cover
227
-
228
- @doc_args(SAVE_ARTIFACT_DOCSTRING)
229
- def save_artifact(
230
- self,
231
- *,
232
- key: str | None = None,
233
- description: str | None = None,
234
- revises: Artifact | None = None,
235
- run: Run | None = None,
236
- ) -> Artifact:
237
- """{}""" # noqa: D415
238
- # Note that this docstring has to be consistent with the Artifact()
239
- # constructor signature
240
- pass # pragma: no cover
241
-
242
-
243
- class SlotsCurator(Curator):
244
- """Curator for a dataset with slots.
245
-
246
- Args:
247
- dataset: The dataset to validate & annotate.
248
- schema: A `Schema` object that defines the validation constraints.
249
-
250
- .. versionadded:: 1.3.0
251
- """
252
-
253
- def __init__(
254
- self,
255
- dataset: Any,
256
- schema: Schema,
257
- ) -> None:
258
- super().__init__(dataset=dataset, schema=schema)
259
- self._slots: dict[str, DataFrameCurator] = {}
260
-
261
- # used in MuDataCurator and SpatialDataCurator
262
- # in form of {table/modality_key: var_field}
263
- self._var_fields: dict[str, FieldAttr] = {}
264
- # in form of {table/modality_key: categoricals}
265
- self._categoricals: dict[str, dict[str, FieldAttr]] = {}
266
-
267
- @property
268
- @doc_args(SLOTS_DOCSTRING)
269
- def slots(self) -> dict[str, DataFrameCurator]:
270
- """{}""" # noqa: D415
271
- return self._slots
272
-
273
- @doc_args(VALIDATE_DOCSTRING)
274
- def validate(self) -> None:
275
- """{}""" # noqa: D415
276
- for _, curator in self._slots.items():
277
- curator.validate()
278
-
279
- @doc_args(SAVE_ARTIFACT_DOCSTRING)
280
- def save_artifact(
281
- self,
282
- *,
283
- key: str | None = None,
284
- description: str | None = None,
285
- revises: Artifact | None = None,
286
- run: Run | None = None,
287
- ) -> Artifact:
288
- """{}""" # noqa: D415
289
- if not self._is_validated:
290
- self.validate()
291
-
292
- # default implementation for MuDataCurator and SpatialDataCurator
293
- return save_artifact( # type: ignore
294
- self._dataset,
295
- key=key,
296
- description=description,
297
- fields=self._categoricals,
298
- index_field=self._var_fields,
299
- artifact=self._artifact,
300
- revises=revises,
301
- run=run,
302
- schema=self._schema,
303
- )
304
-
305
-
306
- def check_dtype(expected_type) -> Callable:
307
- """Creates a check function for Pandera that validates a column's dtype.
308
-
309
- Args:
310
- expected_type: String identifier for the expected type ('int', 'float', or 'num')
311
-
312
- Returns:
313
- A function that checks if a series has the expected dtype
314
- """
315
-
316
- def check_function(series):
317
- if expected_type == "int":
318
- is_valid = pd.api.types.is_integer_dtype(series.dtype)
319
- elif expected_type == "float":
320
- is_valid = pd.api.types.is_float_dtype(series.dtype)
321
- elif expected_type == "num":
322
- is_valid = pd.api.types.is_numeric_dtype(series.dtype)
323
- return is_valid
324
-
325
- return check_function
326
-
327
-
328
- class DataFrameCurator(Curator):
329
- # the example in the docstring is tested in test_curators_quickstart_example
330
- """Curator for `DataFrame`.
331
-
332
- See also :class:`~lamindb.Curator` and :class:`~lamindb.Schema`.
333
-
334
- .. versionadded:: 1.1.0
335
-
336
- Args:
337
- dataset: The DataFrame-like object to validate & annotate.
338
- schema: A `Schema` object that defines the validation constraints.
339
-
340
- Example::
341
-
342
- import lamindb as ln
343
- import bionty as bt
344
-
345
- # define valid labels
346
- perturbation = ln.ULabel(name="Perturbation", is_type=True).save()
347
- ln.ULabel(name="DMSO", type=perturbation).save()
348
- ln.ULabel(name="IFNG", type=perturbation).save()
349
- bt.CellType.from_source(name="B cell").save()
350
- bt.CellType.from_source(name="T cell").save()
351
-
352
- # define schema
353
- schema = ln.Schema(
354
- name="small_dataset1_obs_level_metadata",
355
- features=[
356
- ln.Feature(name="perturbation", dtype="cat[ULabel[Perturbation]]").save(),
357
- ln.Feature(name="sample_note", dtype=str).save(),
358
- ln.Feature(name="cell_type_by_expert", dtype=bt.CellType).save(),
359
- ln.Feature(name="cell_type_by_model", dtype=bt.CellType).save(),
360
- ],
361
- ).save()
362
-
363
- # curate a DataFrame
364
- df = datasets.small_dataset1(otype="DataFrame")
365
- curator = ln.curators.DataFrameCurator(df, schema)
366
- artifact = curator.save_artifact(key="example_datasets/dataset1.parquet")
367
- assert artifact.schema == schema
368
- """
369
-
370
- def __init__(
371
- self,
372
- dataset: pd.DataFrame | Artifact,
373
- schema: Schema,
374
- ) -> None:
375
- super().__init__(dataset=dataset, schema=schema)
376
- categoricals = {}
377
- if schema.n > 0:
378
- # populate features
379
- pandera_columns = {}
380
- for feature in schema.features.all():
381
- if feature.dtype in {"int", "float", "num"}:
382
- dtype = (
383
- self._dataset[feature.name].dtype
384
- if feature.name in self._dataset.columns
385
- else None
386
- )
387
- pandera_columns[feature.name] = pandera.Column(
388
- dtype=None,
389
- checks=pandera.Check(
390
- check_dtype(feature.dtype),
391
- element_wise=False,
392
- error=f"Column '{feature.name}' failed dtype check for '{feature.dtype}': got {dtype}",
393
- ),
394
- nullable=feature.nullable,
395
- coerce=feature.coerce_dtype,
396
- )
397
- else:
398
- pandera_dtype = (
399
- feature.dtype
400
- if not feature.dtype.startswith("cat")
401
- else "category"
402
- )
403
- pandera_columns[feature.name] = pandera.Column(
404
- pandera_dtype,
405
- nullable=feature.nullable,
406
- coerce=feature.coerce_dtype,
407
- )
408
- if feature.dtype.startswith("cat"):
409
- categoricals[feature.name] = parse_dtype(feature.dtype)[0]["field"]
410
- self._pandera_schema = pandera.DataFrameSchema(
411
- pandera_columns, coerce=schema.coerce_dtype
412
- )
413
- else:
414
- assert schema.itype is not None # noqa: S101
415
- self._cat_manager = DataFrameCatManager(
416
- self._dataset,
417
- columns=parse_cat_dtype(schema.itype, is_itype=True)["field"],
418
- categoricals=categoricals,
419
- )
420
-
421
- @property
422
- @doc_args(CAT_MANAGER_DOCSTRING)
423
- def cat(self) -> CatManager:
424
- """{}""" # noqa: D415
425
- return self._cat_manager
426
-
427
- def standardize(self) -> None:
428
- """Standardize the dataset.
429
-
430
- - Adds missing columns for features
431
- - Fills missing values for features with default values
432
- """
433
- for feature in self._schema.members:
434
- if feature.name not in self._dataset.columns:
435
- if feature.default_value is not None or feature.nullable:
436
- fill_value = (
437
- feature.default_value
438
- if feature.default_value is not None
439
- else pd.NA
440
- )
441
- if feature.dtype.startswith("cat"):
442
- self._dataset[feature.name] = pd.Categorical(
443
- [fill_value] * len(self._dataset)
444
- )
445
- else:
446
- self._dataset[feature.name] = fill_value
447
- logger.important(
448
- f"added column {feature.name} with fill value {fill_value}"
449
- )
450
- else:
451
- raise ValidationError(
452
- f"Missing column {feature.name} cannot be added because is not nullable and has no default value"
453
- )
454
- else:
455
- if feature.default_value is not None:
456
- if isinstance(
457
- self._dataset[feature.name].dtype, pd.CategoricalDtype
458
- ):
459
- if (
460
- feature.default_value
461
- not in self._dataset[feature.name].cat.categories
462
- ):
463
- self._dataset[feature.name] = self._dataset[
464
- feature.name
465
- ].cat.add_categories(feature.default_value)
466
- self._dataset[feature.name] = self._dataset[feature.name].fillna(
467
- feature.default_value
468
- )
469
-
470
- def _cat_manager_validate(self) -> None:
471
- self._cat_manager.validate()
472
- if self._cat_manager._is_validated:
473
- self._is_validated = True
474
- else:
475
- self._is_validated = False
476
- raise ValidationError(self._cat_manager._validate_category_error_messages)
477
-
478
- @doc_args(VALIDATE_DOCSTRING)
479
- def validate(self) -> None:
480
- """{}""" # noqa: D415
481
- if self._schema.n > 0:
482
- try:
483
- # first validate through pandera
484
- self._pandera_schema.validate(self._dataset)
485
- # then validate lamindb categoricals
486
- self._cat_manager_validate()
487
- except pandera.errors.SchemaError as err:
488
- self._is_validated = False
489
- # .exconly() doesn't exist on SchemaError
490
- raise ValidationError(str(err)) from err
491
- else:
492
- self._cat_manager_validate()
493
-
494
- @doc_args(SAVE_ARTIFACT_DOCSTRING)
495
- def save_artifact(
496
- self,
497
- *,
498
- key: str | None = None,
499
- description: str | None = None,
500
- revises: Artifact | None = None,
501
- run: Run | None = None,
502
- ) -> Artifact:
503
- """{}""" # noqa: D415
504
- if not self._is_validated:
505
- self.validate() # raises ValidationError if doesn't validate
506
- result = parse_cat_dtype(self._schema.itype, is_itype=True)
507
- return save_artifact( # type: ignore
508
- self._dataset,
509
- description=description,
510
- fields=self._cat_manager.categoricals,
511
- index_field=result["field"],
512
- key=key,
513
- artifact=self._artifact,
514
- revises=revises,
515
- run=run,
516
- schema=self._schema,
517
- )
518
-
519
-
520
- class AnnDataCurator(SlotsCurator):
521
- # the example in the docstring is tested in test_curators_quickstart_example
522
- """Curator for `AnnData`.
523
-
524
- See also :class:`~lamindb.Curator` and :class:`~lamindb.Schema`.
525
-
526
- .. versionadded:: 1.1.0
527
-
528
- Args:
529
- dataset: The AnnData-like object to validate & annotate.
530
- schema: A `Schema` object that defines the validation constraints.
531
-
532
- Example::
533
-
534
- import lamindb as ln
535
- import bionty as bt
536
-
537
- # define valid labels
538
- perturbation = ln.ULabel(name="Perturbation", is_type=True).save()
539
- ln.ULabel(name="DMSO", type=perturbation).save()
540
- ln.ULabel(name="IFNG", type=perturbation).save()
541
- bt.CellType.from_source(name="B cell").save()
542
- bt.CellType.from_source(name="T cell").save()
543
-
544
- # define obs schema
545
- obs_schema = ln.Schema(
546
- name="small_dataset1_obs_level_metadata",
547
- features=[
548
- ln.Feature(name="perturbation", dtype="cat[ULabel[Perturbation]]").save(),
549
- ln.Feature(name="sample_note", dtype=str).save(),
550
- ln.Feature(name="cell_type_by_expert", dtype=bt.CellType).save(),
551
- ln.Feature(name="cell_type_by_model", dtype=bt.CellType").save(),
552
- ],
553
- ).save()
554
-
555
- # define var schema
556
- var_schema = ln.Schema(
557
- name="scRNA_seq_var_schema",
558
- itype=bt.Gene.ensembl_gene_id,
559
- dtype=int,
560
- ).save()
561
-
562
- # define composite schema
563
- anndata_schema = ln.Schema(
564
- name="small_dataset1_anndata_schema",
565
- otype="AnnData",
566
- components={"obs": obs_schema, "var": var_schema},
567
- ).save()
568
-
569
- # curate an AnnData
570
- adata = ln.core.datasets.small_dataset1(otype="AnnData")
571
- curator = ln.curators.AnnDataCurator(adata, anndata_schema)
572
- artifact = curator.save_artifact(key="example_datasets/dataset1.h5ad")
573
- assert artifact.schema == anndata_schema
574
- """
575
-
576
- def __init__(
577
- self,
578
- dataset: AnnData | Artifact,
579
- schema: Schema,
580
- ) -> None:
581
- super().__init__(dataset=dataset, schema=schema)
582
- if not data_is_anndata(self._dataset):
583
- raise InvalidArgument("dataset must be AnnData-like.")
584
- if schema.otype != "AnnData":
585
- raise InvalidArgument("Schema otype must be 'AnnData'.")
586
- # TODO: also support slots other than obs and var
587
- self._slots = {
588
- slot: DataFrameCurator(
589
- (
590
- getattr(self._dataset, slot).T
591
- if slot == "var"
592
- else getattr(self._dataset, slot)
593
- ),
594
- slot_schema,
595
- )
596
- for slot, slot_schema in schema.slots.items()
597
- if slot in {"obs", "var", "uns"}
598
- }
599
-
600
- @doc_args(SAVE_ARTIFACT_DOCSTRING)
601
- def save_artifact(
602
- self,
603
- *,
604
- key: str | None = None,
605
- description: str | None = None,
606
- revises: Artifact | None = None,
607
- run: Run | None = None,
608
- ) -> Artifact:
609
- """{}""" # noqa: D415
610
- if not self._is_validated:
611
- self.validate()
612
- if "obs" in self.slots:
613
- categoricals = self.slots["obs"]._cat_manager.categoricals
614
- else:
615
- categoricals = {}
616
- return save_artifact( # type: ignore
617
- self._dataset,
618
- description=description,
619
- fields=categoricals,
620
- index_field=(
621
- parse_cat_dtype(self.slots["var"]._schema.itype, is_itype=True)["field"]
622
- if "var" in self._slots
623
- else None
624
- ),
625
- key=key,
626
- artifact=self._artifact,
627
- revises=revises,
628
- run=run,
629
- schema=self._schema,
630
- )
631
-
632
-
633
- def _assign_var_fields_categoricals_multimodal(
634
- modality: str | None,
635
- slot_type: str,
636
- slot: str,
637
- slot_schema: Schema,
638
- var_fields: dict[str, FieldAttr],
639
- categoricals: dict[str, dict[str, FieldAttr]],
640
- slots: dict[str, DataFrameCurator],
641
- ) -> None:
642
- """Assigns var_fields and categoricals for multimodal data curators."""
643
- if modality is not None:
644
- # Makes sure that all tables are present
645
- var_fields[modality] = None
646
- categoricals[modality] = {}
647
-
648
- if slot_type == "var":
649
- var_field = parse_cat_dtype(slot_schema.itype, is_itype=True)["field"]
650
- if modality is None:
651
- # This should rarely/never be used since tables should have different var fields
652
- var_fields[slot] = var_field # pragma: no cover
653
- else:
654
- # Note that this is NOT nested since the nested key is always "var"
655
- var_fields[modality] = var_field
656
- else:
657
- obs_fields = slots[slot]._cat_manager.categoricals
658
- if modality is None:
659
- categoricals[slot] = obs_fields
660
- else:
661
- # Note that this is NOT nested since the nested key is always "obs"
662
- categoricals[modality] = obs_fields
663
-
664
-
665
- class MuDataCurator(SlotsCurator):
666
- # the example in the docstring is tested in test_curators_quickstart_example
667
- """Curator for `MuData`.
668
-
669
- See also :class:`~lamindb.Curator` and :class:`~lamindb.Schema`.
670
-
671
- .. versionadded:: 1.3.0
672
-
673
- Args:
674
- dataset: The MuData-like object to validate & annotate.
675
- schema: A `Schema` object that defines the validation constraints.
676
-
677
- Example::
678
-
679
- import lamindb as ln
680
- import bionty as bt
681
-
682
- # define the global obs schema
683
- obs_schema = ln.Schema(
684
- name="mudata_papalexi21_subset_obs_schema",
685
- features=[
686
- ln.Feature(name="perturbation", dtype="cat[ULabel[Perturbation]]").save(),
687
- ln.Feature(name="replicate", dtype="cat[ULabel[Replicate]]").save(),
688
- ],
689
- ).save()
690
-
691
- # define the ['rna'].obs schema
692
- obs_schema_rna = ln.Schema(
693
- name="mudata_papalexi21_subset_rna_obs_schema",
694
- features=[
695
- ln.Feature(name="nCount_RNA", dtype=int).save(),
696
- ln.Feature(name="nFeature_RNA", dtype=int).save(),
697
- ln.Feature(name="percent.mito", dtype=float).save(),
698
- ],
699
- coerce_dtype=True,
700
- ).save()
701
-
702
- # define the ['hto'].obs schema
703
- obs_schema_hto = ln.Schema(
704
- name="mudata_papalexi21_subset_hto_obs_schema",
705
- features=[
706
- ln.Feature(name="nCount_HTO", dtype=int).save(),
707
- ln.Feature(name="nFeature_HTO", dtype=int).save(),
708
- ln.Feature(name="technique", dtype=bt.ExperimentalFactor).save(),
709
- ],
710
- coerce_dtype=True,
711
- ).save()
712
-
713
- # define ['rna'].var schema
714
- var_schema_rna = ln.Schema(
715
- name="mudata_papalexi21_subset_rna_var_schema",
716
- itype=bt.Gene.symbol,
717
- dtype=float,
718
- ).save()
719
-
720
- # define composite schema
721
- mudata_schema = ln.Schema(
722
- name="mudata_papalexi21_subset_mudata_schema",
723
- otype="MuData",
724
- components={
725
- "obs": obs_schema,
726
- "rna:obs": obs_schema_rna,
727
- "hto:obs": obs_schema_hto,
728
- "rna:var": var_schema_rna,
729
- },
730
- ).save()
731
-
732
- # curate a MuData
733
- mdata = ln.core.datasets.mudata_papalexi21_subset()
734
- bt.settings.organism = "human" # set the organism
735
- curator = ln.curators.MuDataCurator(mdata, mudata_schema)
736
- artifact = curator.save_artifact(key="example_datasets/mudata_papalexi21_subset.h5mu")
737
- assert artifact.schema == mudata_schema
738
- """
739
-
740
- def __init__(
741
- self,
742
- dataset: MuData | Artifact,
743
- schema: Schema,
744
- ) -> None:
745
- super().__init__(dataset=dataset, schema=schema)
746
- if not data_is_mudata(self._dataset):
747
- raise InvalidArgument("dataset must be MuData-like.")
748
- if schema.otype != "MuData":
749
- raise InvalidArgument("Schema otype must be 'MuData'.")
750
-
751
- for slot, slot_schema in schema.slots.items():
752
- # Assign to _slots
753
- if ":" in slot:
754
- modality, modality_slot = slot.split(":")
755
- schema_dataset = self._dataset.__getitem__(modality)
756
- else:
757
- modality, modality_slot = None, slot
758
- schema_dataset = self._dataset
759
- self._slots[slot] = DataFrameCurator(
760
- (
761
- getattr(schema_dataset, modality_slot).T
762
- if modality_slot == "var"
763
- else getattr(schema_dataset, modality_slot)
764
- ),
765
- slot_schema,
766
- )
767
- _assign_var_fields_categoricals_multimodal(
768
- modality=modality,
769
- slot_type=modality_slot,
770
- slot=slot,
771
- slot_schema=slot_schema,
772
- var_fields=self._var_fields,
773
- categoricals=self._categoricals,
774
- slots=self._slots,
775
- )
776
-
777
- # for consistency with BaseCatManager
778
- self._columns_field = self._var_fields
779
-
780
-
781
- class SpatialDataCurator(SlotsCurator):
782
- # the example in the docstring is tested in test_curators_quickstart_example
783
- """Curator for `SpatialData`.
784
-
785
- See also :class:`~lamindb.Curator` and :class:`~lamindb.Schema`.
786
-
787
- .. versionadded:: 1.3.0
788
-
789
- Args:
790
- dataset: The SpatialData-like object to validate & annotate.
791
- schema: A `Schema` object that defines the validation constraints.
792
-
793
- Example::
794
-
795
- import lamindb as ln
796
- import bionty as bt
797
-
798
- # define sample schema
799
- sample_schema = ln.Schema(
800
- name="blobs_sample_level_metadata",
801
- features=[
802
- ln.Feature(name="assay", dtype=bt.ExperimentalFactor).save(),
803
- ln.Feature(name="disease", dtype=bt.Disease).save(),
804
- ln.Feature(name="development_stage", dtype=bt.DevelopmentalStage).save(),
805
- ],
806
- coerce_dtype=True
807
- ).save()
808
-
809
- # define table obs schema
810
- blobs_obs_schema = ln.Schema(
811
- name="blobs_obs_level_metadata",
812
- features=[
813
- ln.Feature(name="sample_region", dtype="str").save(),
814
- ],
815
- coerce_dtype=True
816
- ).save()
817
-
818
- # define table var schema
819
- blobs_var_schema = ln.Schema(
820
- name="blobs_var_schema",
821
- itype=bt.Gene.ensembl_gene_id,
822
- dtype=int
823
- ).save()
824
-
825
- # define composite schema
826
- spatialdata_schema = ln.Schema(
827
- name="blobs_spatialdata_schema",
828
- otype="SpatialData",
829
- components={
830
- "sample": sample_schema,
831
- "table:obs": blobs_obs_schema,
832
- "table:var": blobs_var_schema,
833
- }).save()
834
-
835
- # curate a SpatialData
836
- spatialdata = ln.core.datasets.spatialdata_blobs()
837
- curator = ln.curators.SpatialDataCurator(spatialdata, spatialdata_schema)
838
- try:
839
- curator.validate()
840
- except ln.errors.ValidationError as error:
841
- print(error)
842
-
843
- # validate again (must pass now) and save artifact
844
- artifact = curator.save_artifact(key="example_datasets/spatialdata1.zarr")
845
- assert artifact.schema == spatialdata_schema
846
- """
847
-
848
- def __init__(
849
- self,
850
- dataset: SpatialData | Artifact,
851
- schema: Schema,
852
- *,
853
- sample_metadata_key: str | None = "sample",
854
- ) -> None:
855
- super().__init__(dataset=dataset, schema=schema)
856
- if not data_is_spatialdata(self._dataset):
857
- raise InvalidArgument("dataset must be SpatialData-like.")
858
- if schema.otype != "SpatialData":
859
- raise InvalidArgument("Schema otype must be 'SpatialData'.")
860
-
861
- for slot, slot_schema in schema.slots.items():
862
- # Assign to _slots
863
- if ":" in slot:
864
- table_key, table_slot = slot.split(":")
865
- schema_dataset = self._dataset.tables.__getitem__(table_key)
866
- # sample metadata (does not have a `:` separator)
867
- else:
868
- table_key = None
869
- table_slot = slot
870
- schema_dataset = self._dataset.get_attrs(
871
- key=sample_metadata_key, return_as="df", flatten=True
872
- )
873
-
874
- self._slots[slot] = DataFrameCurator(
875
- (
876
- getattr(schema_dataset, table_slot).T
877
- if table_slot == "var"
878
- else (
879
- getattr(schema_dataset, table_slot)
880
- if table_slot != sample_metadata_key
881
- else schema_dataset
882
- ) # just take the schema_dataset if it's the sample metadata key
883
- ),
884
- slot_schema,
885
- )
886
-
887
- _assign_var_fields_categoricals_multimodal(
888
- modality=table_key,
889
- slot_type=table_slot,
890
- slot=slot,
891
- slot_schema=slot_schema,
892
- var_fields=self._var_fields,
893
- categoricals=self._categoricals,
894
- slots=self._slots,
895
- )
896
-
897
- # for consistency with BaseCatManager
898
- self._columns_field = self._var_fields
899
-
900
-
901
- class CatManager:
902
- """Manage categoricals by updating registries.
903
-
904
- This class is accessible from within a `DataFrameCurator` via the `.cat` attribute.
905
-
906
- If you find non-validated values, you have several options:
907
-
908
- - new values found in the data can be registered via `DataFrameCurator.cat.add_new_from()` :meth:`~lamindb.curators.DataFrameCatManager.add_new_from`
909
- - non-validated values can be accessed via `DataFrameCurator.cat.add_new_from()` :meth:`~lamindb.curators.DataFrameCatManager.non_validated` and addressed manually
910
- """
911
-
912
- def __init__(self, *, dataset, categoricals, sources, organism, columns_field=None):
913
- # the below is shared with Curator
914
- self._artifact: Artifact = None # pass the dataset as an artifact
915
- self._dataset: Any = dataset # pass the dataset as a UPathStr or data object
916
- if isinstance(self._dataset, Artifact):
917
- self._artifact = self._dataset
918
- if self._artifact.otype in {"DataFrame", "AnnData"}:
919
- self._dataset = self._dataset.load()
920
- self._is_validated: bool = False
921
- # shared until here
922
- self._categoricals = categoricals or {}
923
- self._non_validated = None
924
- self._sources = sources or {}
925
- self._columns_field = columns_field
926
- self._validate_category_error_messages: str = ""
927
- # make sure to only fetch organism once at the beginning
928
- if organism:
929
- self._organism = organism
930
- else:
931
- fields = list(self._categoricals.values()) + [columns_field]
932
- organisms = {get_organism_kwargs(field).get("organism") for field in fields}
933
- self._organism = organisms.pop() if len(organisms) > 0 else None
934
-
935
- @property
936
- def non_validated(self) -> dict[str, list[str]]:
937
- """Return the non-validated features and labels."""
938
- if self._non_validated is None:
939
- raise ValidationError("Please run validate() first!")
940
- return self._non_validated
941
-
942
- @property
943
- def categoricals(self) -> dict:
944
- """Return the columns fields to validate against."""
945
- return self._categoricals
946
-
947
- def _replace_synonyms(
948
- self, key: str, syn_mapper: dict, values: pd.Series | pd.Index
949
- ):
950
- # replace the values in df
951
- std_values = values.map(lambda unstd_val: syn_mapper.get(unstd_val, unstd_val))
952
- # remove the standardized values from self.non_validated
953
- non_validated = [i for i in self.non_validated[key] if i not in syn_mapper]
954
- if len(non_validated) == 0:
955
- self._non_validated.pop(key, None) # type: ignore
956
- else:
957
- self._non_validated[key] = non_validated # type: ignore
958
- # logging
959
- n = len(syn_mapper)
960
- if n > 0:
961
- syn_mapper_print = _format_values(
962
- [f'"{k}" → "{v}"' for k, v in syn_mapper.items()], sep=""
963
- )
964
- s = "s" if n > 1 else ""
965
- logger.success(
966
- f'standardized {n} synonym{s} in "{key}": {colors.green(syn_mapper_print)}'
967
- )
968
- return std_values
969
-
970
- def validate(self) -> bool:
971
- """Validate dataset.
972
-
973
- This method also registers the validated records in the current instance.
974
-
975
- Returns:
976
- The boolean `True` if the dataset is validated. Otherwise, a string with the error message.
977
- """
978
- pass # pragma: no cover
979
-
980
- def standardize(self, key: str) -> None:
981
- """Replace synonyms with standardized values.
982
-
983
- Inplace modification of the dataset.
984
-
985
- Args:
986
- key: The name of the column to standardize.
987
-
988
- Returns:
989
- None
990
- """
991
- pass # pragma: no cover
992
-
993
- @doc_args(SAVE_ARTIFACT_DOCSTRING)
994
- def save_artifact(
995
- self,
996
- *,
997
- key: str | None = None,
998
- description: str | None = None,
999
- revises: Artifact | None = None,
1000
- run: Run | None = None,
1001
- ) -> Artifact:
1002
- """{}""" # noqa: D415
1003
- # Make sure all labels are saved in the current instance
1004
- if not self._is_validated:
1005
- self.validate() # returns True or False
1006
- if not self._is_validated: # need to raise error manually
1007
- raise ValidationError("Dataset does not validate. Please curate.")
1008
-
1009
- self._artifact = save_artifact( # type: ignore
1010
- self._dataset,
1011
- key=key,
1012
- description=description,
1013
- fields=self.categoricals,
1014
- index_field=self._columns_field,
1015
- artifact=self._artifact,
1016
- revises=revises,
1017
- run=run,
1018
- schema=None,
1019
- organism=self._organism,
1020
- )
1021
-
1022
- return self._artifact
1023
-
1024
-
1025
- class DataFrameCatManager(CatManager):
1026
- """Categorical manager for `DataFrame`."""
1027
-
1028
- def __init__(
1029
- self,
1030
- df: pd.DataFrame | Artifact,
1031
- columns: FieldAttr = Feature.name,
1032
- categoricals: dict[str, FieldAttr] | None = None,
1033
- verbosity: str = "hint",
1034
- organism: str | None = None,
1035
- sources: dict[str, Record] | None = None,
1036
- ) -> None:
1037
- if organism is not None and not isinstance(organism, str):
1038
- raise ValueError("organism must be a string such as 'human' or 'mouse'!")
1039
-
1040
- settings.verbosity = verbosity
1041
- self._non_validated = None
1042
- super().__init__(
1043
- dataset=df,
1044
- columns_field=columns,
1045
- organism=organism,
1046
- categoricals=categoricals,
1047
- sources=sources,
1048
- )
1049
- self._save_columns()
1050
-
1051
- def lookup(self, public: bool = False) -> CatLookup:
1052
- """Lookup categories.
1053
-
1054
- Args:
1055
- public: If "public", the lookup is performed on the public reference.
1056
- """
1057
- return CatLookup(
1058
- categoricals=self._categoricals,
1059
- slots={"columns": self._columns_field},
1060
- public=public,
1061
- organism=self._organism,
1062
- sources=self._sources,
1063
- )
1064
-
1065
- def _save_columns(self, validated_only: bool = True) -> None:
1066
- """Save column name records."""
1067
- # Always save features specified as the fields keys
1068
- update_registry(
1069
- values=list(self.categoricals.keys()),
1070
- field=self._columns_field,
1071
- key="columns" if isinstance(self._dataset, pd.DataFrame) else "keys",
1072
- validated_only=False,
1073
- source=self._sources.get("columns"),
1074
- )
1075
-
1076
- # Save the rest of the columns based on validated_only
1077
- additional_columns = set(self._dataset.keys()) - set(self.categoricals.keys())
1078
- if additional_columns:
1079
- update_registry(
1080
- values=list(additional_columns),
1081
- field=self._columns_field,
1082
- key="columns" if isinstance(self._dataset, pd.DataFrame) else "keys",
1083
- validated_only=validated_only,
1084
- df=self._dataset, # Get the Feature type from df
1085
- source=self._sources.get("columns"),
1086
- )
1087
-
1088
- @deprecated(new_name="is run by default")
1089
- def add_new_from_columns(self, organism: str | None = None, **kwargs):
1090
- pass # pragma: no cover
1091
-
1092
- def validate(self) -> bool:
1093
- """Validate variables and categorical observations.
1094
-
1095
- This method also registers the validated records in the current instance:
1096
- - from public sources
1097
-
1098
- Args:
1099
- organism: The organism name.
1100
-
1101
- Returns:
1102
- Whether the DataFrame is validated.
1103
- """
1104
- # add all validated records to the current instance
1105
- self._update_registry_all()
1106
- self._validate_category_error_messages = "" # reset the error messages
1107
- self._is_validated, self._non_validated = validate_categories_in_df( # type: ignore
1108
- self._dataset,
1109
- fields=self.categoricals,
1110
- sources=self._sources,
1111
- curator=self,
1112
- organism=self._organism,
1113
- )
1114
- return self._is_validated
1115
-
1116
- def standardize(self, key: str) -> None:
1117
- """Replace synonyms with standardized values.
1118
-
1119
- Modifies the input dataset inplace.
1120
-
1121
- Args:
1122
- key: The key referencing the column in the DataFrame to standardize.
1123
- """
1124
- if self._artifact is not None:
1125
- raise RuntimeError("can't mutate the dataset when an artifact is passed!")
1126
- # list is needed to avoid RuntimeError: dictionary changed size during iteration
1127
- avail_keys = list(self.non_validated.keys())
1128
- if len(avail_keys) == 0:
1129
- logger.warning("values are already standardized")
1130
- return
1131
-
1132
- if key == "all":
1133
- for k in avail_keys:
1134
- if k in self._categoricals: # needed to exclude var_index
1135
- syn_mapper = standardize_categories(
1136
- self.non_validated[k],
1137
- field=self._categoricals[k],
1138
- source=self._sources.get(k),
1139
- )
1140
- self._dataset[k] = self._replace_synonyms(
1141
- k, syn_mapper, self._dataset[k]
1142
- )
1143
- else:
1144
- if key not in avail_keys:
1145
- if key in self._categoricals:
1146
- logger.warning(f"No non-standardized values found for {key!r}")
1147
- else:
1148
- raise KeyError(
1149
- f"{key!r} is not a valid key, available keys are: {_format_values(avail_keys)}!"
1150
- )
1151
- else:
1152
- if key in self._categoricals: # needed to exclude var_index
1153
- syn_mapper = standardize_categories(
1154
- self.non_validated[key],
1155
- field=self._categoricals[key],
1156
- source=self._sources.get(key),
1157
- organism=self._organism,
1158
- )
1159
- self._dataset[key] = self._replace_synonyms(
1160
- key, syn_mapper, self._dataset[key]
1161
- )
1162
-
1163
- def _update_registry_all(self, validated_only: bool = True, **kwargs):
1164
- """Save labels for all features."""
1165
- for name in self.categoricals.keys():
1166
- self._update_registry(name, validated_only=validated_only, **kwargs)
1167
-
1168
- def _update_registry(
1169
- self, categorical: str, validated_only: bool = True, **kwargs
1170
- ) -> None:
1171
- if categorical == "all":
1172
- self._update_registry_all(validated_only=validated_only, **kwargs)
1173
- else:
1174
- if categorical not in self.categoricals:
1175
- raise ValidationError(
1176
- f"Feature {categorical} is not part of the fields!"
1177
- )
1178
- update_registry(
1179
- values=_flatten_unique(self._dataset[categorical]),
1180
- field=self.categoricals[categorical],
1181
- key=categorical,
1182
- validated_only=validated_only,
1183
- source=self._sources.get(categorical),
1184
- organism=self._organism,
1185
- )
1186
- # adding new records removes them from non_validated
1187
- if not validated_only and self._non_validated:
1188
- self._non_validated.pop(categorical, None) # type: ignore
1189
-
1190
- def add_new_from(self, key: str, **kwargs):
1191
- """Add validated & new categories.
1192
-
1193
- Args:
1194
- key: The key referencing the slot in the DataFrame from which to draw terms.
1195
- organism: The organism name.
1196
- **kwargs: Additional keyword arguments to pass to create new records
1197
- """
1198
- if len(kwargs) > 0 and key == "all":
1199
- raise ValueError("Cannot pass additional arguments to 'all' key!")
1200
- self._update_registry(key, validated_only=False, **kwargs)
1201
-
1202
- def clean_up_failed_runs(self):
1203
- """Clean up previous failed runs that don't save any outputs."""
1204
- from lamindb.core._context import context
1205
-
1206
- if context.run is not None:
1207
- Run.filter(transform=context.run.transform, output_artifacts=None).exclude(
1208
- uid=context.run.uid
1209
- ).delete()
1210
-
1211
-
1212
- class AnnDataCatManager(CatManager):
1213
- """Categorical manager for `AnnData`."""
1214
-
1215
- def __init__(
1216
- self,
1217
- data: ad.AnnData | Artifact,
1218
- var_index: FieldAttr | None = None,
1219
- categoricals: dict[str, FieldAttr] | None = None,
1220
- obs_columns: FieldAttr = Feature.name,
1221
- verbosity: str = "hint",
1222
- organism: str | None = None,
1223
- sources: dict[str, Record] | None = None,
1224
- ) -> None:
1225
- if isinstance(var_index, str):
1226
- raise TypeError(
1227
- "var_index parameter has to be a field, e.g. Gene.ensembl_gene_id"
1228
- )
1229
-
1230
- if not data_is_anndata(data):
1231
- raise TypeError("data has to be an AnnData object")
1232
-
1233
- if "symbol" in str(var_index):
1234
- logger.warning(
1235
- "indexing datasets with gene symbols can be problematic: https://docs.lamin.ai/faq/symbol-mapping"
1236
- )
1237
-
1238
- self._obs_fields = categoricals or {}
1239
- self._var_field = var_index
1240
- self._sources = sources or {}
1241
- super().__init__(
1242
- dataset=data,
1243
- categoricals=categoricals,
1244
- sources=self._sources,
1245
- organism=organism,
1246
- columns_field=var_index,
1247
- )
1248
- self._adata = self._dataset
1249
- self._obs_df_curator = DataFrameCatManager(
1250
- df=self._adata.obs,
1251
- categoricals=self.categoricals,
1252
- columns=obs_columns,
1253
- verbosity=verbosity,
1254
- organism=None,
1255
- sources=self._sources,
1256
- )
1257
-
1258
- @property
1259
- def var_index(self) -> FieldAttr:
1260
- """Return the registry field to validate variables index against."""
1261
- return self._var_field
1262
-
1263
- @property
1264
- def categoricals(self) -> dict:
1265
- """Return the obs fields to validate against."""
1266
- return self._obs_fields
1267
-
1268
- def lookup(self, public: bool = False) -> CatLookup:
1269
- """Lookup categories.
1270
-
1271
- Args:
1272
- public: If "public", the lookup is performed on the public reference.
1273
- """
1274
- return CatLookup(
1275
- categoricals=self._obs_fields,
1276
- slots={"columns": self._columns_field, "var_index": self._var_field},
1277
- public=public,
1278
- organism=self._organism,
1279
- sources=self._sources,
1280
- )
1281
-
1282
- def _save_from_var_index(
1283
- self,
1284
- validated_only: bool = True,
1285
- ):
1286
- """Save variable records."""
1287
- if self.var_index is not None:
1288
- update_registry(
1289
- values=list(self._adata.var.index),
1290
- field=self.var_index,
1291
- key="var_index",
1292
- validated_only=validated_only,
1293
- organism=self._organism,
1294
- source=self._sources.get("var_index"),
1295
- )
1296
-
1297
- def add_new_from(self, key: str, **kwargs):
1298
- """Add validated & new categories.
1299
-
1300
- Args:
1301
- key: The key referencing the slot in the DataFrame from which to draw terms.
1302
- organism: The organism name.
1303
- **kwargs: Additional keyword arguments to pass to create new records
1304
- """
1305
- self._obs_df_curator.add_new_from(key, **kwargs)
1306
-
1307
- def add_new_from_var_index(self, **kwargs):
1308
- """Update variable records.
1309
-
1310
- Args:
1311
- organism: The organism name.
1312
- **kwargs: Additional keyword arguments to pass to create new records.
1313
- """
1314
- self._save_from_var_index(validated_only=False, **kwargs)
1315
-
1316
- def validate(self) -> bool:
1317
- """Validate categories.
1318
-
1319
- This method also registers the validated records in the current instance.
1320
-
1321
- Args:
1322
- organism: The organism name.
1323
-
1324
- Returns:
1325
- Whether the AnnData object is validated.
1326
- """
1327
- self._validate_category_error_messages = "" # reset the error messages
1328
-
1329
- # add all validated records to the current instance
1330
- self._save_from_var_index(validated_only=True)
1331
- if self.var_index is not None:
1332
- validated_var, non_validated_var = validate_categories(
1333
- self._adata.var.index,
1334
- field=self._var_field,
1335
- key="var_index",
1336
- source=self._sources.get("var_index"),
1337
- hint_print=".add_new_from_var_index()",
1338
- organism=self._organism, # type: ignore
1339
- )
1340
- else:
1341
- validated_var = True
1342
- non_validated_var = []
1343
- validated_obs = self._obs_df_curator.validate()
1344
- self._non_validated = self._obs_df_curator._non_validated # type: ignore
1345
- if len(non_validated_var) > 0:
1346
- self._non_validated["var_index"] = non_validated_var # type: ignore
1347
- self._is_validated = validated_var and validated_obs
1348
- return self._is_validated
1349
-
1350
- def standardize(self, key: str):
1351
- """Replace synonyms with standardized values.
1352
-
1353
- Args:
1354
- key: The key referencing the slot in `adata.obs` from which to draw terms. Same as the key in `categoricals`.
1355
-
1356
- - If "var_index", standardize the var.index.
1357
- - If "all", standardize all obs columns and var.index.
1358
-
1359
- Inplace modification of the dataset.
1360
- """
1361
- if self._artifact is not None:
1362
- raise RuntimeError("can't mutate the dataset when an artifact is passed!")
1363
- if key in self._adata.obs.columns or key == "all":
1364
- # standardize obs columns
1365
- self._obs_df_curator.standardize(key)
1366
- # in addition to the obs columns, standardize the var.index
1367
- if key == "var_index" or key == "all":
1368
- syn_mapper = standardize_categories(
1369
- self._adata.var.index,
1370
- field=self.var_index,
1371
- source=self._sources.get("var_index"),
1372
- organism=self._organism,
1373
- )
1374
- if "var_index" in self._non_validated: # type: ignore
1375
- self._adata.var.index = self._replace_synonyms(
1376
- "var_index", syn_mapper, self._adata.var.index
1377
- )
1378
-
1379
-
1380
- class MuDataCatManager(CatManager):
1381
- """Categorical manager for `MuData`."""
1382
-
1383
- def __init__(
1384
- self,
1385
- mdata: MuData | Artifact,
1386
- var_index: dict[str, FieldAttr] | None = None,
1387
- categoricals: dict[str, FieldAttr] | None = None,
1388
- verbosity: str = "hint",
1389
- organism: str | None = None,
1390
- sources: dict[str, Record] | None = None,
1391
- ) -> None:
1392
- super().__init__(
1393
- dataset=mdata,
1394
- categoricals={},
1395
- sources=sources,
1396
- organism=organism,
1397
- )
1398
- self._columns_field = (
1399
- var_index or {}
1400
- ) # this is for consistency with BaseCatManager
1401
- self._var_fields = var_index or {}
1402
- self._verify_modality(self._var_fields.keys())
1403
- self._obs_fields = self._parse_categoricals(categoricals or {})
1404
- self._modalities = set(self._var_fields.keys()) | set(self._obs_fields.keys())
1405
- self._verbosity = verbosity
1406
- self._obs_df_curator = None
1407
- if "obs" in self._modalities:
1408
- self._obs_df_curator = DataFrameCatManager(
1409
- df=self._dataset.obs,
1410
- columns=Feature.name,
1411
- categoricals=self._obs_fields.get("obs", {}),
1412
- verbosity=verbosity,
1413
- sources=self._sources.get("obs"),
1414
- organism=organism,
1415
- )
1416
- self._mod_adata_curators = {
1417
- modality: AnnDataCatManager(
1418
- data=self._dataset[modality],
1419
- var_index=var_index.get(modality),
1420
- categoricals=self._obs_fields.get(modality),
1421
- verbosity=verbosity,
1422
- sources=self._sources.get(modality),
1423
- organism=organism,
1424
- )
1425
- for modality in self._modalities
1426
- if modality != "obs"
1427
- }
1428
- self._non_validated = None
1429
-
1430
- @property
1431
- def var_index(self) -> FieldAttr:
1432
- """Return the registry field to validate variables index against."""
1433
- return self._var_fields
1434
-
1435
- @property
1436
- def categoricals(self) -> dict:
1437
- """Return the obs fields to validate against."""
1438
- return self._obs_fields
1439
-
1440
- @property
1441
- def non_validated(self) -> dict[str, dict[str, list[str]]]: # type: ignore
1442
- """Return the non-validated features and labels."""
1443
- if self._non_validated is None:
1444
- raise ValidationError("Please run validate() first!")
1445
- return self._non_validated
1446
-
1447
- def _verify_modality(self, modalities: Iterable[str]):
1448
- """Verify the modality exists."""
1449
- for modality in modalities:
1450
- if modality not in self._dataset.mod.keys():
1451
- raise ValidationError(f"modality '{modality}' does not exist!")
1452
-
1453
- def _parse_categoricals(self, categoricals: dict[str, FieldAttr]) -> dict:
1454
- """Parse the categorical fields."""
1455
- prefixes = {f"{k}:" for k in self._dataset.mod.keys()}
1456
- obs_fields: dict[str, dict[str, FieldAttr]] = {}
1457
- for k, v in categoricals.items():
1458
- if k not in self._dataset.obs.columns:
1459
- raise ValidationError(f"column '{k}' does not exist in mdata.obs!")
1460
- if any(k.startswith(prefix) for prefix in prefixes):
1461
- modality, col = k.split(":")[0], k.split(":")[1]
1462
- if modality not in obs_fields.keys():
1463
- obs_fields[modality] = {}
1464
- obs_fields[modality][col] = v
1465
- else:
1466
- if "obs" not in obs_fields.keys():
1467
- obs_fields["obs"] = {}
1468
- obs_fields["obs"][k] = v
1469
- return obs_fields
1470
-
1471
- def lookup(self, public: bool = False) -> CatLookup:
1472
- """Lookup categories.
1473
-
1474
- Args:
1475
- public: Perform lookup on public source ontologies.
1476
- """
1477
- obs_fields = {}
1478
- for mod, fields in self._obs_fields.items():
1479
- for k, v in fields.items():
1480
- if k == "obs":
1481
- obs_fields[k] = v
1482
- else:
1483
- obs_fields[f"{mod}:{k}"] = v
1484
- return CatLookup(
1485
- categoricals=obs_fields,
1486
- slots={
1487
- **{f"{k}_var_index": v for k, v in self._var_fields.items()},
1488
- },
1489
- public=public,
1490
- organism=self._organism,
1491
- sources=self._sources,
1492
- )
1493
-
1494
- @deprecated(new_name="is run by default")
1495
- def add_new_from_columns(
1496
- self,
1497
- modality: str,
1498
- column_names: list[str] | None = None,
1499
- **kwargs,
1500
- ):
1501
- pass # pragma: no cover
1502
-
1503
- def add_new_from_var_index(self, modality: str, **kwargs):
1504
- """Update variable records.
1505
-
1506
- Args:
1507
- modality: The modality name.
1508
- organism: The organism name.
1509
- **kwargs: Additional keyword arguments to pass to create new records.
1510
- """
1511
- self._mod_adata_curators[modality].add_new_from_var_index(**kwargs)
1512
-
1513
- def _update_registry_all(self):
1514
- """Update all registries."""
1515
- if self._obs_df_curator is not None:
1516
- self._obs_df_curator._update_registry_all(validated_only=True)
1517
- for _, adata_curator in self._mod_adata_curators.items():
1518
- adata_curator._obs_df_curator._update_registry_all(validated_only=True)
1519
-
1520
- def add_new_from(
1521
- self,
1522
- key: str,
1523
- modality: str | None = None,
1524
- **kwargs,
1525
- ):
1526
- """Add validated & new categories.
1527
-
1528
- Args:
1529
- key: The key referencing the slot in the DataFrame.
1530
- modality: The modality name.
1531
- organism: The organism name.
1532
- **kwargs: Additional keyword arguments to pass to create new records.
1533
- """
1534
- if len(kwargs) > 0 and key == "all":
1535
- raise ValueError("Cannot pass additional arguments to 'all' key!")
1536
- modality = modality or "obs"
1537
- if modality in self._mod_adata_curators:
1538
- adata_curator = self._mod_adata_curators[modality]
1539
- adata_curator.add_new_from(key=key, **kwargs)
1540
- if modality == "obs":
1541
- self._obs_df_curator.add_new_from(key=key, **kwargs)
1542
-
1543
- def validate(self) -> bool:
1544
- """Validate categories."""
1545
- # add all validated records to the current instance
1546
- self._update_registry_all()
1547
- self._non_validated = {} # type: ignore
1548
-
1549
- obs_validated = True
1550
- if "obs" in self._modalities:
1551
- logger.info('validating categoricals in "obs"...')
1552
- obs_validated &= self._obs_df_curator.validate()
1553
- self._non_validated["obs"] = self._obs_df_curator.non_validated # type: ignore
1554
- logger.print("")
1555
-
1556
- mods_validated = True
1557
- for modality, adata_curator in self._mod_adata_curators.items():
1558
- logger.info(f'validating categoricals in modality "{modality}"...')
1559
- mods_validated &= adata_curator.validate()
1560
- if len(adata_curator.non_validated) > 0:
1561
- self._non_validated[modality] = adata_curator.non_validated # type: ignore
1562
- logger.print("")
1563
-
1564
- self._is_validated = obs_validated & mods_validated
1565
- return self._is_validated
1566
-
1567
- def standardize(self, key: str, modality: str | None = None):
1568
- """Replace synonyms with standardized values.
1569
-
1570
- Args:
1571
- key: The key referencing the slot in the `MuData`.
1572
- modality: The modality name.
1573
-
1574
- Inplace modification of the dataset.
1575
- """
1576
- if self._artifact is not None:
1577
- raise RuntimeError("can't mutate the dataset when an artifact is passed!")
1578
- modality = modality or "obs"
1579
- if modality in self._mod_adata_curators:
1580
- adata_curator = self._mod_adata_curators[modality]
1581
- adata_curator.standardize(key=key)
1582
- if modality == "obs":
1583
- self._obs_df_curator.standardize(key=key)
1584
-
1585
-
1586
- def _maybe_curation_keys_not_present(nonval_keys: list[str], name: str):
1587
- if (n := len(nonval_keys)) > 0:
1588
- s = "s" if n > 1 else ""
1589
- are = "are" if n > 1 else "is"
1590
- raise ValidationError(
1591
- f"key{s} passed to {name} {are} not present: {colors.yellow(_format_values(nonval_keys))}"
1592
- )
1593
-
1594
-
1595
- class SpatialDataCatManager(CatManager):
1596
- """Categorical manager for `SpatialData`."""
1597
-
1598
- def __init__(
1599
- self,
1600
- sdata: Any,
1601
- var_index: dict[str, FieldAttr],
1602
- categoricals: dict[str, dict[str, FieldAttr]] | None = None,
1603
- verbosity: str = "hint",
1604
- organism: str | None = None,
1605
- sources: dict[str, dict[str, Record]] | None = None,
1606
- *,
1607
- sample_metadata_key: str | None = "sample",
1608
- ) -> None:
1609
- super().__init__(
1610
- dataset=sdata,
1611
- categoricals={},
1612
- sources=sources,
1613
- organism=organism,
1614
- )
1615
- if isinstance(sdata, Artifact):
1616
- self._sdata = sdata.load()
1617
- else:
1618
- self._sdata = self._dataset
1619
- self._sample_metadata_key = sample_metadata_key
1620
- self._write_path = None
1621
- self._var_fields = var_index
1622
- self._verify_accessor_exists(self._var_fields.keys())
1623
- self._categoricals = categoricals
1624
- self._table_keys = set(self._var_fields.keys()) | set(
1625
- self._categoricals.keys() - {self._sample_metadata_key}
1626
- )
1627
- self._verbosity = verbosity
1628
- self._sample_df_curator = None
1629
- if self._sample_metadata_key is not None:
1630
- self._sample_metadata = self._sdata.get_attrs(
1631
- key=self._sample_metadata_key, return_as="df", flatten=True
1632
- )
1633
- self._is_validated = False
1634
-
1635
- # Check validity of keys in categoricals
1636
- nonval_keys = []
1637
- for accessor, accessor_categoricals in self._categoricals.items():
1638
- if (
1639
- accessor == self._sample_metadata_key
1640
- and self._sample_metadata is not None
1641
- ):
1642
- for key in accessor_categoricals.keys():
1643
- if key not in self._sample_metadata.columns:
1644
- nonval_keys.append(key)
1645
- else:
1646
- for key in accessor_categoricals.keys():
1647
- if key not in self._sdata[accessor].obs.columns:
1648
- nonval_keys.append(key)
1649
-
1650
- _maybe_curation_keys_not_present(nonval_keys, "categoricals")
1651
-
1652
- # check validity of keys in sources
1653
- nonval_keys = []
1654
- for accessor, accessor_sources in self._sources.items():
1655
- if (
1656
- accessor == self._sample_metadata_key
1657
- and self._sample_metadata is not None
1658
- ):
1659
- columns = self._sample_metadata.columns
1660
- elif accessor != self._sample_metadata_key:
1661
- columns = self._sdata[accessor].obs.columns
1662
- else:
1663
- continue
1664
- for key in accessor_sources:
1665
- if key not in columns:
1666
- nonval_keys.append(key)
1667
- _maybe_curation_keys_not_present(nonval_keys, "sources")
1668
-
1669
- # Set up sample level metadata and table Curator objects
1670
- if (
1671
- self._sample_metadata_key is not None
1672
- and self._sample_metadata_key in self._categoricals
1673
- ):
1674
- self._sample_df_curator = DataFrameCatManager(
1675
- df=self._sample_metadata,
1676
- columns=Feature.name,
1677
- categoricals=self._categoricals.get(self._sample_metadata_key, {}),
1678
- verbosity=verbosity,
1679
- sources=self._sources.get(self._sample_metadata_key),
1680
- organism=organism,
1681
- )
1682
- self._table_adata_curators = {
1683
- table: AnnDataCatManager(
1684
- data=self._sdata[table],
1685
- var_index=var_index.get(table),
1686
- categoricals=self._categoricals.get(table),
1687
- verbosity=verbosity,
1688
- sources=self._sources.get(table),
1689
- organism=organism,
1690
- )
1691
- for table in self._table_keys
1692
- }
1693
-
1694
- self._non_validated = None
1695
-
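For orientation, a hedged sketch of how this constructor is meant to be called; the accessor names, columns, and registry fields below are placeholders, not part of the class:

```python
import bionty as bt

# Illustrative only: "table1" is a table accessor in the SpatialData object and
# "sample" holds sample-level metadata retrieved via get_attrs().
manager = SpatialDataCatManager(
    sdata,
    var_index={"table1": bt.Gene.ensembl_gene_id},
    categoricals={
        "sample": {"organism": bt.Organism.name},   # validated on sample metadata
        "table1": {"cell_type": bt.CellType.name},  # validated on the table's .obs
    },
    organism="human",
)
```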
1696
- @property
1697
- def var_index(self) -> FieldAttr:
1698
- """Return the registry fields to validate variables indices against."""
1699
- return self._var_fields
1700
-
1701
- @property
1702
- def categoricals(self) -> dict[str, dict[str, FieldAttr]]:
1703
- """Return the categorical keys and fields to validate against."""
1704
- return self._categoricals
1705
-
1706
- @property
1707
- def non_validated(self) -> dict[str, dict[str, list[str]]]: # type: ignore
1708
- """Return the non-validated features and labels."""
1709
- if self._non_validated is None:
1710
- raise ValidationError("Please run validate() first!")
1711
- return self._non_validated
1712
-
1713
- def _verify_accessor_exists(self, accessors: Iterable[str]) -> None:
1714
- """Verify that the accessors exist (either a valid table or in attrs)."""
1715
- for acc in accessors:
1716
- is_present = False
1717
- try:
1718
- self._sdata.get_attrs(key=acc)
1719
- is_present = True
1720
- except KeyError:
1721
- if acc in self._sdata.tables.keys():
1722
- is_present = True
1723
- if not is_present:
1724
- raise ValidationError(f"Accessor '{acc}' does not exist!")
1725
-
1726
- def lookup(self, public: bool = False) -> CatLookup:
1727
- """Look up categories.
1728
-
1729
- Args:
1730
- public: Whether the lookup is performed on the public reference.
1731
- """
1732
- cat_values_dict = list(self.categoricals.values())[0]
1733
- return CatLookup(
1734
- categoricals=cat_values_dict,
1735
- slots={"accessors": cat_values_dict.keys()},
1736
- public=public,
1737
- organism=self._organism,
1738
- sources=self._sources,
1739
- )
1740
-
1741
- def _update_registry_all(self) -> None:
1742
- """Saves labels of all features for sample and table metadata."""
1743
- if self._sample_df_curator is not None:
1744
- self._sample_df_curator._update_registry_all(
1745
- validated_only=True,
1746
- )
1747
- for _, adata_curator in self._table_adata_curators.items():
1748
- adata_curator._obs_df_curator._update_registry_all(
1749
- validated_only=True,
1750
- )
1751
-
1752
- def add_new_from_var_index(self, table: str, **kwargs) -> None:
1753
- """Save new values from ``.var.index`` of table.
1754
-
1755
- Args:
1756
- table: The table key.
1757
- organism: The organism name.
1758
- **kwargs: Additional keyword arguments to pass to create new records.
1759
- """
1760
- if self._non_validated is None:
1761
- raise ValidationError("Run .validate() first.")
1762
- self._table_adata_curators[table].add_new_from_var_index(**kwargs)
1763
- if table in self.non_validated.keys():
1764
- if "var_index" in self._non_validated[table]:
1765
- self._non_validated[table].pop("var_index")
1766
-
1767
- if len(self.non_validated[table].values()) == 0:
1768
- self.non_validated.pop(table)
1769
-
1770
- def add_new_from(
1771
- self,
1772
- key: str,
1773
- accessor: str | None = None,
1774
- **kwargs,
1775
- ) -> None:
1776
- """Save new values of categorical from sample level metadata or table.
1777
-
1778
- Args:
1779
- key: The key referencing the slot in the DataFrame.
1780
- accessor: The accessor key such as 'sample' or 'table x'.
1781
- organism: The organism name.
1782
- **kwargs: Additional keyword arguments to pass to create new records.
1783
- """
1784
- if self._non_validated is None:
1785
- raise ValidationError("Run .validate() first.")
1786
-
1787
- if len(kwargs) > 0 and key == "all":
1788
- raise ValueError("Cannot pass additional arguments to 'all' key!")
1789
-
1790
- if accessor not in self.categoricals:
1791
- raise ValueError(
1792
- f"Accessor {accessor} is not in 'categoricals'. Include it when creating the SpatialDataCatManager."
1793
- )
1794
-
1795
- if accessor in self._table_adata_curators:
1796
- adata_curator = self._table_adata_curators[accessor]
1797
- adata_curator.add_new_from(key=key, **kwargs)
1798
- if accessor == self._sample_metadata_key:
1799
- self._sample_df_curator.add_new_from(key=key, **kwargs)
1800
-
1801
- if accessor in self.non_validated.keys():
1802
- if len(self.non_validated[accessor].values()) == 0:
1803
- self.non_validated.pop(accessor)
1804
-
1805
- def standardize(self, key: str, accessor: str | None = None) -> None:
1806
- """Replace synonyms with canonical values.
1807
-
1808
- Modifies the dataset inplace.
1809
-
1810
- Args:
1811
- key: The key referencing the slot in the table or sample metadata.
1812
- accessor: The accessor key such as 'sample_key' or 'table_key'.
1813
- """
1814
- if len(self.non_validated) == 0:
1815
- logger.warning("values are already standardized")
1816
- return
1817
- if self._artifact is not None:
1818
- raise RuntimeError("can't mutate the dataset when an artifact is passed!")
1819
-
1820
- if accessor == self._sample_metadata_key:
1821
- if key not in self._sample_metadata.columns:
1822
- raise ValueError(f"key '{key}' not present in '{accessor}'!")
1823
- else:
1824
- if (
1825
- key == "var_index" and self._sdata.tables[accessor].var.index is None
1826
- ) or (
1827
- key != "var_index"
1828
- and key not in self._sdata.tables[accessor].obs.columns
1829
- ):
1830
- raise ValueError(f"key '{key}' not present in '{accessor}'!")
1831
-
1832
- if accessor in self._table_adata_curators.keys():
1833
- adata_curator = self._table_adata_curators[accessor]
1834
- adata_curator.standardize(key)
1835
- if accessor == self._sample_metadata_key:
1836
- self._sample_df_curator.standardize(key)
1837
-
1838
- if len(self.non_validated[accessor].values()) == 0:
1839
- self.non_validated.pop(accessor)
1840
-
1841
- def validate(self) -> bool:
1842
- """Validate variables and categorical observations.
1843
-
1844
- This method also registers records validated against public sources in the current instance.
1849
-
1850
- Returns:
1851
- Whether the SpatialData object is validated.
1852
- """
1853
- # add all validated records to the current instance
1854
- self._update_registry_all()
1855
-
1856
- self._non_validated = {} # type: ignore
1857
-
1858
- sample_validated = True
1859
- if self._sample_df_curator:
1860
- logger.info(f"validating categoricals of '{self._sample_metadata_key}' ...")
1861
- sample_validated &= self._sample_df_curator.validate()
1862
- if len(self._sample_df_curator.non_validated) > 0:
1863
- self._non_validated["sample"] = self._sample_df_curator.non_validated # type: ignore
1864
- logger.print("")
1865
-
1866
- mods_validated = True
1867
- for table, adata_curator in self._table_adata_curators.items():
1868
- logger.info(f"validating categoricals of table '{table}' ...")
1869
- mods_validated &= adata_curator.validate()
1870
- if len(adata_curator.non_validated) > 0:
1871
- self._non_validated[table] = adata_curator.non_validated # type: ignore
1872
- logger.print("")
1873
-
1874
- self._is_validated = sample_validated & mods_validated
1875
- return self._is_validated
1876
-
1877
- def save_artifact(
1878
- self,
1879
- *,
1880
- key: str | None = None,
1881
- description: str | None = None,
1882
- revises: Artifact | None = None,
1883
- run: Run | None = None,
1884
- ) -> Artifact:
1885
- """Save the validated SpatialData store and metadata.
1886
-
1887
- Args:
- key: A path-like key to reference artifact in default storage,
- e.g., `"myartifact.zarr"`. Artifacts with the same key form a version family.
- description: A description of the dataset.
1891
- revises: Previous version of the artifact. Triggers a revision.
1892
- run: The run that creates the artifact.
1893
-
1894
- Returns:
1895
- A saved artifact record.
1896
- """
1897
- if not self._is_validated:
1898
- self.validate()
1899
- if not self._is_validated:
1900
- raise ValidationError("Dataset does not validate. Please curate.")
1901
-
1902
- return save_artifact(
1903
- self._sdata,
1904
- description=description,
1905
- fields=self.categoricals,
1906
- index_field=self.var_index,
1907
- key=key,
1908
- artifact=self._artifact,
1909
- revises=revises,
1910
- run=run,
1911
- schema=None,
1912
- organism=self._organism,
1913
- sample_metadata_key=self._sample_metadata_key,
1914
- )
1915
-
1916
-
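A short usage sketch for the method above, assuming the hypothetical `manager` from earlier validates; the key and description are illustrative:

```python
if manager.validate():
    artifact = manager.save_artifact(
        key="curated/visium_lung.zarr",             # placeholder key
        description="Curated Visium lung dataset",  # placeholder description
    )
```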
1917
- class TiledbsomaCatManager(CatManager):
1918
- """Categorical manager for `tiledbsoma.Experiment`."""
1919
-
1920
- def __init__(
1921
- self,
1922
- experiment_uri: UPathStr | Artifact,
1923
- var_index: dict[str, tuple[str, FieldAttr]],
1924
- categoricals: dict[str, FieldAttr] | None = None,
1925
- obs_columns: FieldAttr = Feature.name,
1926
- organism: str | None = None,
1927
- sources: dict[str, Record] | None = None,
1928
- ):
1929
- self._obs_fields = categoricals or {}
1930
- self._var_fields = var_index
1931
- self._columns_field = obs_columns
1932
- if isinstance(experiment_uri, Artifact):
1933
- self._dataset = experiment_uri.path
1934
- self._artifact = experiment_uri
1935
- else:
1936
- self._dataset = UPath(experiment_uri)
1937
- self._artifact = None
1938
- self._organism = organism
1939
- self._sources = sources or {}
1940
-
1941
- self._is_validated: bool | None = False
1942
- self._non_validated_values: dict[str, list] | None = None
1943
- self._validated_values: dict[str, list] = {}
1944
- # filled by _check_save_keys
1945
- self._n_obs: int | None = None
1946
- self._valid_obs_keys: list[str] | None = None
1947
- self._obs_pa_schema: pa.lib.Schema | None = (
1948
- None # this is needed to create the obs feature set
1949
- )
1950
- self._valid_var_keys: list[str] | None = None
1951
- self._var_fields_flat: dict[str, FieldAttr] | None = None
1952
- self._check_save_keys()
1953
-
1954
- # check that the provided keys in var_index and categoricals are available in the store
1955
- # and save features
1956
- def _check_save_keys(self):
1957
- from lamindb.core.storage._tiledbsoma import _open_tiledbsoma
1958
-
1959
- with _open_tiledbsoma(self._dataset, mode="r") as experiment:
1960
- experiment_obs = experiment.obs
1961
- self._n_obs = len(experiment_obs)
1962
- self._obs_pa_schema = experiment_obs.schema
1963
- valid_obs_keys = [
1964
- k for k in self._obs_pa_schema.names if k != "soma_joinid"
1965
- ]
1966
- self._valid_obs_keys = valid_obs_keys
1967
-
1968
- valid_var_keys = []
1969
- ms_list = []
1970
- for ms in experiment.ms.keys():
1971
- ms_list.append(ms)
1972
- var_ms = experiment.ms[ms].var
1973
- valid_var_keys += [
1974
- f"{ms}__{k}" for k in var_ms.keys() if k != "soma_joinid"
1975
- ]
1976
- self._valid_var_keys = valid_var_keys
1977
-
1978
- # check validity of keys in categoricals
1979
- nonval_keys = []
1980
- for obs_key in self._obs_fields.keys():
1981
- if obs_key not in valid_obs_keys:
1982
- nonval_keys.append(obs_key)
1983
- _maybe_curation_keys_not_present(nonval_keys, "categoricals")
1984
-
1985
- # check validity of keys in var_index
1986
- self._var_fields_flat = {}
1987
- nonval_keys = []
1988
- for ms_key in self._var_fields.keys():
1989
- var_key, var_field = self._var_fields[ms_key]
1990
- var_key_flat = f"{ms_key}__{var_key}"
1991
- if var_key_flat not in valid_var_keys:
1992
- nonval_keys.append(f"({ms_key}, {var_key})")
1993
- else:
1994
- self._var_fields_flat[var_key_flat] = var_field
1995
- _maybe_curation_keys_not_present(nonval_keys, "var_index")
1996
-
1997
- # check validity of keys in sources
1998
- valid_arg_keys = valid_obs_keys + valid_var_keys + ["columns"]
1999
- nonval_keys = []
2000
- for arg_key in self._sources.keys():
2001
- if arg_key not in valid_arg_keys:
2002
- nonval_keys.append(arg_key)
2003
- _maybe_curation_keys_not_present(nonval_keys, "sources")
2004
-
2005
- # register obs columns' names
2006
- register_columns = list(self._obs_fields.keys())
2007
- update_registry(
2008
- values=register_columns,
2009
- field=self._columns_field,
2010
- key="columns",
2011
- validated_only=False,
2012
- organism=self._organism,
2013
- source=self._sources.get("columns"),
2014
- )
2015
- additional_columns = [k for k in valid_obs_keys if k not in register_columns]
2016
- # no need to register with validated_only=True if columns are features
2017
- if (
2018
- len(additional_columns) > 0
2019
- and self._columns_field.field.model is not Feature
2020
- ):
2021
- update_registry(
2022
- values=additional_columns,
2023
- field=self._columns_field,
2024
- key="columns",
2025
- validated_only=True,
2026
- organism=self._organism,
2027
- source=self._sources.get("columns"),
2028
- )
2029
-
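To make the key conventions above concrete, here is a hedged construction sketch; the URI, measurement name, and fields are placeholders. Note that `var_index` maps each measurement to a `(column, field)` tuple, which `_check_save_keys` flattens into `'{measurement}__{column}'` keys:

```python
import bionty as bt

soma_manager = TiledbsomaCatManager(
    "s3://my-bucket/my-experiment.tiledbsoma",           # placeholder URI
    var_index={"RNA": ("var_id", bt.Gene.ensembl_gene_id)},
    categoricals={"cell_type": bt.CellType.name},        # an .obs column
)
# internally, var_index is flattened to {"RNA__var_id": bt.Gene.ensembl_gene_id}
```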
2030
- def validate(self):
2031
- """Validate categories."""
2032
- from lamindb.core.storage._tiledbsoma import _open_tiledbsoma
2033
-
2034
- validated = True
2035
- self._non_validated_values = {}
2036
- with _open_tiledbsoma(self._dataset, mode="r") as experiment:
2037
- for ms, (key, field) in self._var_fields.items():
2038
- var_ms = experiment.ms[ms].var
2039
- var_ms_key = f"{ms}__{key}"
2040
- # it was already validated and cached
2041
- if var_ms_key in self._validated_values:
2042
- continue
2043
- var_ms_values = (
2044
- var_ms.read(column_names=[key]).concat()[key].to_pylist()
2045
- )
2046
- update_registry(
2047
- values=var_ms_values,
2048
- field=field,
2049
- key=var_ms_key,
2050
- validated_only=True,
2051
- organism=self._organism,
2052
- source=self._sources.get(var_ms_key),
2053
- )
2054
- _, non_val = validate_categories(
2055
- values=var_ms_values,
2056
- field=field,
2057
- key=var_ms_key,
2058
- organism=self._organism,
2059
- source=self._sources.get(var_ms_key),
2060
- )
2061
- if len(non_val) > 0:
2062
- validated = False
2063
- self._non_validated_values[var_ms_key] = non_val
2064
- else:
2065
- self._validated_values[var_ms_key] = var_ms_values
2066
-
2067
- obs = experiment.obs
2068
- for key, field in self._obs_fields.items():
2069
- # already validated and cached
2070
- if key in self._validated_values:
2071
- continue
2072
- values = pa.compute.unique(
2073
- obs.read(column_names=[key]).concat()[key]
2074
- ).to_pylist()
2075
- update_registry(
2076
- values=values,
2077
- field=field,
2078
- key=key,
2079
- validated_only=True,
2080
- organism=self._organism,
2081
- source=self._sources.get(key),
2082
- )
2083
- _, non_val = validate_categories(
2084
- values=values,
2085
- field=field,
2086
- key=key,
2087
- organism=self._organism,
2088
- source=self._sources.get(key),
2089
- )
2090
- if len(non_val) > 0:
2091
- validated = False
2092
- self._non_validated_values[key] = non_val
2093
- else:
2094
- self._validated_values[key] = values
2095
- self._is_validated = validated
2096
- return self._is_validated
2097
-
2098
- def _non_validated_values_field(self, key: str) -> tuple[list, FieldAttr]:
2099
- assert self._non_validated_values is not None # noqa: S101
2100
-
2101
- if key in self._valid_obs_keys:
2102
- field = self._obs_fields[key]
2103
- elif key in self._valid_var_keys:
2104
- ms = key.partition("__")[0]
2105
- field = self._var_fields[ms][1]
2106
- else:
2107
- raise KeyError(f"key {key} is invalid!")
2108
- values = self._non_validated_values.get(key, [])
2109
- return values, field
2110
-
2111
- def add_new_from(self, key: str, **kwargs) -> None:
2112
- """Add validated & new categories.
2113
-
2114
- Args:
2115
- key: The key referencing the slot in the `tiledbsoma` store.
2116
- It should be `'{measurement name}__{column name in .var}'` for columns in `.var`
2117
- or a column name in `.obs`.
2118
- """
2119
- if self._non_validated_values is None:
2120
- raise ValidationError("Run .validate() first.")
2121
- if key == "all":
2122
- keys = list(self._non_validated_values.keys())
2123
- else:
2124
- avail_keys = list(
2125
- chain(self._non_validated_values.keys(), self._validated_values.keys())
2126
- )
2127
- if key not in avail_keys:
2128
- raise KeyError(
2129
- f"'{key!r}' is not a valid key, available keys are: {_format_values(avail_keys + ['all'])}!"
2130
- )
2131
- keys = [key]
2132
- for k in keys:
2133
- values, field = self._non_validated_values_field(k)
2134
- if len(values) == 0:
2135
- continue
2136
- update_registry(
2137
- values=values,
2138
- field=field,
2139
- key=k,
2140
- validated_only=False,
2141
- organism=self._organism,
2142
- source=self._sources.get(k),
2143
- **kwargs,
2144
- )
2145
- # update non-validated values list but keep the key there
2146
- # it will be removed by .validate()
2147
- if k in self._non_validated_values:
2148
- self._non_validated_values[k] = []
2149
-
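The flattened-key convention in practice, continuing the hypothetical `soma_manager` sketch from above:

```python
soma_manager.validate()
soma_manager.add_new_from("RNA__var_id")   # a .var column of measurement "RNA"
soma_manager.add_new_from("cell_type")     # an .obs column
soma_manager.add_new_from("all")           # or everything that failed validation
```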
2150
- @property
2151
- def non_validated(self) -> dict[str, list]:
2152
- """Return the non-validated features and labels."""
2153
- non_val = {k: v for k, v in self._non_validated_values.items() if v != []}
2154
- return non_val
2155
-
2156
- @property
2157
- def var_index(self) -> dict[str, FieldAttr]:
2158
- """Return the registry fields with flattened keys to validate variables indices against."""
2159
- return self._var_fields_flat
2160
-
2161
- @property
2162
- def categoricals(self) -> dict[str, FieldAttr]:
2163
- """Return the obs fields to validate against."""
2164
- return self._obs_fields
2165
-
2166
- def lookup(self, public: bool = False) -> CatLookup:
2167
- """Lookup categories.
2168
-
2169
- Args:
2170
- public: If "public", the lookup is performed on the public reference.
2171
- """
2172
- return CatLookup(
2173
- categoricals=self._obs_fields,
2174
- slots={"columns": self._columns_field, **self._var_fields_flat},
2175
- public=public,
2176
- organism=self._organism,
2177
- sources=self._sources,
2178
- )
2179
-
2180
- def standardize(self, key: str):
2181
- """Replace synonyms with standardized values.
2182
-
2183
- Modifies the dataset inplace.
2184
-
2185
- Args:
2186
- key: The key referencing the slot in the `tiledbsoma` store.
2187
- It should be `'{measurement name}__{column name in .var}'` for columns in `.var`
2188
- or a column name in `.obs`.
2189
- """
2190
- if len(self.non_validated) == 0:
2191
- logger.warning("values are already standardized")
2192
- return
2193
- avail_keys = list(self._non_validated_values.keys())
2194
- if key == "all":
2195
- keys = avail_keys
2196
- else:
2197
- if key not in avail_keys:
2198
- raise KeyError(
2199
- f"'{key!r}' is not a valid key, available keys are: {_format_values(avail_keys + ['all'])}!"
2200
- )
2201
- keys = [key]
2202
-
2203
- for k in keys:
2204
- values, field = self._non_validated_values_field(k)
2205
- if len(values) == 0:
2206
- continue
2207
- if k in self._valid_var_keys:
2208
- ms, _, slot_key = k.partition("__")
2209
- slot = lambda experiment: experiment.ms[ms].var # noqa: B023
2210
- else:
2211
- slot = lambda experiment: experiment.obs
2212
- slot_key = k
2213
- syn_mapper = standardize_categories(
2214
- values=values,
2215
- field=field,
2216
- source=self._sources.get(k),
2217
- organism=self._organism,
2218
- )
2219
- if (n_syn_mapper := len(syn_mapper)) == 0:
2220
- continue
2221
-
2222
- from lamindb.core.storage._tiledbsoma import _open_tiledbsoma
2223
-
2224
- with _open_tiledbsoma(self._dataset, mode="r") as experiment:
2225
- value_filter = f"{slot_key} in {list(syn_mapper.keys())}"
2226
- table = slot(experiment).read(value_filter=value_filter).concat()
2227
-
2228
- if len(table) == 0:
2229
- continue
2230
-
2231
- df = table.to_pandas()
2232
- # map values
2233
- df[slot_key] = df[slot_key].map(
2234
- lambda val: syn_mapper.get(val, val) # noqa
2235
- )
2236
- # write the mapped values
2237
- with _open_tiledbsoma(self._dataset, mode="w") as experiment:
2238
- slot(experiment).write(pa.Table.from_pandas(df, schema=table.schema))
2239
- # update non_validated dict
2240
- non_val_k = [
2241
- nv for nv in self._non_validated_values[k] if nv not in syn_mapper
2242
- ]
2243
- self._non_validated_values[k] = non_val_k
2244
-
2245
- syn_mapper_print = _format_values(
2246
- [f'"{m_k}" → "{m_v}"' for m_k, m_v in syn_mapper.items()], sep=""
2247
- )
2248
- s = "s" if n_syn_mapper > 1 else ""
2249
- logger.success(
2250
- f'standardized {n_syn_mapper} synonym{s} in "{k}": {colors.green(syn_mapper_print)}'
2251
- )
2252
-
2253
- def save_artifact(
2254
- self,
2255
- *,
2256
- key: str | None = None,
2257
- description: str | None = None,
2258
- revises: Artifact | None = None,
2259
- run: Run | None = None,
2260
- ) -> Artifact:
2261
- """Save the validated `tiledbsoma` store and metadata.
2262
-
2263
- Args:
2264
- key: A path-like key to reference artifact in default storage,
- e.g., `"myfolder/mystore.tiledbsoma"`. Artifacts with the same key form a version family.
- description: A description of the ``tiledbsoma`` store.
2267
- revises: Previous version of the artifact. Triggers a revision.
2268
- run: The run that creates the artifact.
2269
-
2270
- Returns:
2271
- A saved artifact record.
2272
- """
2273
- if not self._is_validated:
2274
- self.validate()
2275
- if not self._is_validated:
2276
- raise ValidationError("Dataset does not validate. Please curate.")
2277
-
2278
- if self._artifact is None:
2279
- artifact = Artifact(
2280
- self._dataset,
2281
- description=description,
2282
- key=key,
2283
- revises=revises,
2284
- run=run,
2285
- )
2286
- artifact.n_observations = self._n_obs
2287
- artifact.otype = "tiledbsoma"
2288
- artifact.save()
2289
- else:
2290
- artifact = self._artifact
2291
-
2292
- feature_sets = {}
2293
- if len(self._obs_fields) > 0:
2294
- empty_dict = {field.name: [] for field in self._obs_pa_schema} # type: ignore
2295
- mock_df = pa.Table.from_pydict(
2296
- empty_dict, schema=self._obs_pa_schema
2297
- ).to_pandas()
2298
- # in parallel to https://github.com/laminlabs/lamindb/blob/2a1709990b5736b480c6de49c0ada47fafc8b18d/lamindb/core/_feature_manager.py#L549-L554
2299
- feature_sets["obs"] = Schema.from_df(
2300
- df=mock_df,
2301
- field=self._columns_field,
2302
- mute=True,
2303
- organism=self._organism,
2304
- )
2305
- for ms in self._var_fields:
2306
- var_key, var_field = self._var_fields[ms]
2307
- feature_sets[f"{ms}__var"] = Schema.from_values(
2308
- values=self._validated_values[f"{ms}__{var_key}"],
2309
- field=var_field,
2310
- organism=self._organism,
2311
- raise_validation_error=False,
2312
- )
2313
- artifact._staged_feature_sets = feature_sets
2314
-
2315
- feature_ref_is_name = _ref_is_name(self._columns_field)
2316
- features = Feature.lookup().dict()
2317
- for key, field in self._obs_fields.items():
2318
- feature = features.get(key)
2319
- registry = field.field.model
2320
- labels = registry.from_values(
2321
- values=self._validated_values[key],
2322
- field=field,
2323
- organism=self._organism,
2324
- )
2325
- if len(labels) == 0:
2326
- continue
2327
- if hasattr(registry, "_name_field"):
2328
- label_ref_is_name = field.field.name == registry._name_field
2329
- add_labels(
2330
- artifact,
2331
- records=labels,
2332
- feature=feature,
2333
- feature_ref_is_name=feature_ref_is_name,
2334
- label_ref_is_name=label_ref_is_name,
2335
- from_curator=True,
2336
- )
2337
-
2338
- return artifact.save()
2339
-
2340
-
2341
- class CellxGeneAnnDataCatManager(AnnDataCatManager):
2342
- """Categorical manager for `AnnData` respecting the CELLxGENE schema.
2343
-
2344
- This will be superseded by a schema-based curation flow.
2345
- """
2346
-
2347
- cxg_categoricals_defaults = {
2348
- "cell_type": "unknown",
2349
- "development_stage": "unknown",
2350
- "disease": "normal",
2351
- "donor_id": "unknown",
2352
- "self_reported_ethnicity": "unknown",
2353
- "sex": "unknown",
2354
- "suspension_type": "cell",
2355
- "tissue_type": "tissue",
2356
- }
2357
-
2358
- def __init__(
2359
- self,
2360
- adata: ad.AnnData,
2361
- categoricals: dict[str, FieldAttr] | None = None,
2362
- organism: Literal["human", "mouse"] = "human",
2363
- *,
2364
- schema_version: Literal["4.0.0", "5.0.0", "5.1.0", "5.2.0"] = "5.2.0",
2365
- defaults: dict[str, str] = None,
2366
- extra_sources: dict[str, Record] = None,
2367
- verbosity: str = "hint",
2368
- ) -> None:
2369
- """CELLxGENE schema curator.
2370
-
2371
- Args:
2372
- adata: Path to an AnnData object, or an in-memory AnnData object, to curate against the CELLxGENE schema.
2373
- categoricals: A dictionary mapping ``.obs.columns`` to a registry field.
2374
- The CELLxGENE Curator maps against the required CELLxGENE fields by default.
2375
- organism: The organism name. CELLxGENE restricts it to 'human' and 'mouse'.
2376
- schema_version: The CELLxGENE schema version to curate against.
2377
- defaults: Default values that are set if columns or column values are missing.
2378
- extra_sources: A dictionary mapping ``.obs.columns`` to Source records.
2379
- These extra sources are joined with the CELLxGENE fixed sources.
2380
- Use this parameter when subclassing.
2381
- verbosity: The verbosity level.
2382
- """
2383
- import bionty as bt
2384
-
2385
- from ._cellxgene_schemas import (
2386
- _add_defaults_to_obs,
2387
- _create_sources,
2388
- _init_categoricals_additional_values,
2389
- _restrict_obs_fields,
2390
- )
2391
-
2392
- # Add defaults first to ensure that we fetch valid sources
2393
- if defaults:
2394
- _add_defaults_to_obs(adata.obs, defaults)
2395
-
2396
- # Filter categoricals based on what's present in adata
2397
- if categoricals is None:
2398
- categoricals = self._get_cxg_categoricals()
2399
- categoricals = _restrict_obs_fields(adata.obs, categoricals)
2400
-
2401
- # Configure sources
2402
- sources = _create_sources(categoricals, schema_version, organism)
2403
- self.schema_version = schema_version
2404
- self.schema_reference = f"https://github.com/chanzuckerberg/single-cell-curation/blob/main/schema/{schema_version}/schema.md"
2405
- # These sources are not part of the CELLxGENE schema but are passed through.
- # This is useful when other curators extend the CELLxGENE curator.
2407
- if extra_sources:
2408
- sources = sources | extra_sources
2409
-
2410
- _init_categoricals_additional_values()
2411
-
2412
- super().__init__(
2413
- data=adata,
2414
- var_index=bt.Gene.ensembl_gene_id,
2415
- categoricals=categoricals,
2416
- verbosity=verbosity,
2417
- organism=organism,
2418
- sources=sources,
2419
- )
2420
-
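A hedged construction sketch for the CELLxGENE manager; `adata` is your AnnData object, and passing the class-level defaults fills required columns that are missing:

```python
cxg_manager = CellxGeneAnnDataCatManager(
    adata,
    organism="human",
    schema_version="5.2.0",
    defaults=CellxGeneAnnDataCatManager.cxg_categoricals_defaults,
)
```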
2421
- @classmethod
2422
- @deprecated(new_name="cxg_categoricals_defaults")
2423
- def _get_categoricals_defaults(cls) -> dict[str, str]:
2424
- return cls.cxg_categoricals_defaults
2425
-
2426
- @classmethod
2427
- def _get_cxg_categoricals(cls) -> dict[str, FieldAttr]:
2428
- """Returns the CELLxGENE schema mapped fields."""
2429
- from ._cellxgene_schemas import _get_cxg_categoricals
2430
-
2431
- return _get_cxg_categoricals()
2432
-
2433
- def validate(self) -> bool:
2434
- """Validates the AnnData object against most cellxgene requirements."""
2435
- from ._cellxgene_schemas import RESERVED_NAMES
2436
-
2437
- # Verify that all required obs columns are present
2438
- required_columns = list(self.cxg_categoricals_defaults.keys()) + ["donor_id"]
2439
- missing_obs_fields = [
2440
- name
2441
- for name in required_columns
2442
- if name not in self._adata.obs.columns
2443
- and f"{name}_ontology_term_id" not in self._adata.obs.columns
2444
- ]
2445
- if len(missing_obs_fields) > 0:
2446
- logger.error(
2447
- f"missing required obs columns {_format_values(missing_obs_fields)}\n"
2448
- " → consider initializing a Curate object with `defaults=cxg.CellxGeneAnnDataCatManager.cxg_categoricals_defaults` to automatically add these columns with default values"
2449
- )
2450
- return False
2451
-
2452
- # Verify that no cellxgene reserved names are present
2453
- matched_columns = [
2454
- column for column in self._adata.obs.columns if column in RESERVED_NAMES
2455
- ]
2456
- if len(matched_columns) > 0:
2457
- raise ValueError(
2458
- f"AnnData object must not contain obs columns {matched_columns} which are"
2459
- " reserved from previous schema versions."
2460
- )
2461
-
2462
- return super().validate()
2463
-
2464
- def to_cellxgene_anndata(
2465
- self, is_primary_data: bool, title: str | None = None
2466
- ) -> ad.AnnData:
2467
- """Converts the AnnData object to the cellxgene-schema input format.
2468
-
2469
- cellxgene expects the obs fields to be {entity}_ontology_term_id fields and has many further requirements, which are
- documented here: https://github.com/chanzuckerberg/single-cell-curation/tree/main/schema.
2471
- This function checks for most but not all requirements of the CELLxGENE schema.
2472
- If you want to ensure that it fully adheres to the CELLxGENE schema, run `cellxgene-schema` on the AnnData object.
2473
-
2474
- Args:
2475
- is_primary_data: Whether the measured data is primary data or not.
2476
- title: Title of the AnnData object. Commonly the name of the publication.
2477
-
2478
- Returns:
2479
- An AnnData object which adheres to the cellxgene-schema.
2480
- """
2481
-
2482
- def _convert_name_to_ontology_id(values: pd.Series, field: FieldAttr):
2483
- """Converts a column that stores a name into a column that stores the ontology id.
2484
-
2485
- cellxgene expects the obs columns to be {entity}_ontology_term_id columns and disallows {entity} columns.
2486
- """
2487
- field_name = field.field.name
2488
- assert field_name == "name" # noqa: S101
2489
- cols = ["name", "ontology_id"]
2490
- registry = field.field.model
2491
-
2492
- if hasattr(registry, "ontology_id"):
2493
- validated_records = registry.filter(**{f"{field_name}__in": values})
2494
- mapper = (
2495
- pd.DataFrame(validated_records.values_list(*cols))
2496
- .set_index(0)
2497
- .to_dict()[1]
2498
- )
2499
- return values.map(mapper)
2500
-
2501
- # Create a copy since we modify the AnnData object extensively
2502
- adata_cxg = self._adata.copy()
2503
-
2504
- # cellxgene requires an embedding
2505
- embedding_pattern = r"^[a-zA-Z][a-zA-Z0-9_.-]*$"
2506
- exclude_key = "spatial"
2507
- matching_keys = [
2508
- key
2509
- for key in adata_cxg.obsm.keys()
2510
- if re.match(embedding_pattern, key) and key != exclude_key
2511
- ]
2512
- if len(matching_keys) == 0:
2513
- raise ValueError(
2514
- "Unable to find an embedding key. Please calculate an embedding."
2515
- )
2516
-
2517
- # convert name column to ontology_term_id column
2518
- for column in adata_cxg.obs.columns:
2519
- if column in self.categoricals and not column.endswith("_ontology_term_id"):
2520
- mapped_column = _convert_name_to_ontology_id(
2521
- adata_cxg.obs[column], field=self.categoricals.get(column)
2522
- )
2523
- if mapped_column is not None:
2524
- adata_cxg.obs[f"{column}_ontology_term_id"] = mapped_column
2525
-
2526
- # drop the name columns for ontologies. cellxgene does not allow them.
2527
- drop_columns = [
2528
- i
2529
- for i in adata_cxg.obs.columns
2530
- if f"{i}_ontology_term_id" in adata_cxg.obs.columns
2531
- ]
2532
- adata_cxg.obs.drop(columns=drop_columns, inplace=True)
2533
-
2534
- # Add cellxgene metadata to AnnData object
2535
- if "is_primary_data" not in adata_cxg.obs.columns:
2536
- adata_cxg.obs["is_primary_data"] = is_primary_data
2537
- if "feature_is_filtered" not in adata_cxg.var.columns:
2538
- logger.warn(
2539
- "column 'feature_is_filtered' not present in var. Setting to default"
2540
- " value of False."
2541
- )
2542
- adata_cxg.var["feature_is_filtered"] = False
2543
- if title is None:
2544
- raise ValueError("please pass a title!")
2545
- else:
2546
- adata_cxg.uns["title"] = title
2547
- adata_cxg.uns["cxg_lamin_schema_reference"] = self.schema_reference
2548
- adata_cxg.uns["cxg_lamin_schema_version"] = self.schema_version
2549
-
2550
- return adata_cxg
2551
-
2552
-
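A usage sketch for the conversion above; as the docstring notes, `cellxgene-schema` remains the authoritative validator:

```python
if cxg_manager.validate():
    adata_cxg = cxg_manager.to_cellxgene_anndata(
        is_primary_data=True,
        title="My study",                  # placeholder publication title
    )
    adata_cxg.write_h5ad("for_cellxgene.h5ad")
```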
2553
- class ValueUnit:
2554
- """Base class for handling value-unit combinations."""
2555
-
2556
- @staticmethod
2557
- def parse_value_unit(value: str, is_dose: bool = True) -> tuple[str, str] | None:
2558
- """Parse a string containing a value and unit into a tuple."""
2559
- if not isinstance(value, str) or not value.strip():
2560
- return None
2561
-
2562
- value = str(value).strip()
2563
- match = re.match(r"^(\d*\.?\d{0,1})\s*([a-zA-ZμµΜ]+)$", value)
2564
-
2565
- if not match:
2566
- raise ValueError(
2567
- f"Invalid format: {value}. Expected format: number with max 1 decimal place + unit"
2568
- )
2569
-
2570
- number, unit = match.groups()
2571
- formatted_number = f"{float(number):.1f}"
2572
-
2573
- if is_dose:
2574
- standardized_unit = DoseHandler.standardize_unit(unit)
2575
- if not DoseHandler.validate_unit(standardized_unit):
2576
- raise ValueError(
2577
- f"Invalid dose unit: {unit}. Must be convertible to one of: nM, μM, mM, M"
2578
- )
2579
- else:
2580
- standardized_unit = TimeHandler.standardize_unit(unit)
2581
- if not TimeHandler.validate_unit(standardized_unit):
2582
- raise ValueError(
2583
- f"Invalid time unit: {unit}. Must be convertible to one of: h, m, s, d, y"
2584
- )
2585
-
2586
- return formatted_number, standardized_unit
2587
-
2588
-
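Worked examples of what the parser accepts, checked against the regex above (inputs are illustrative):

```python
ValueUnit.parse_value_unit("10nM", is_dose=True)    # -> ("10.0", "nM")
ValueUnit.parse_value_unit("2 hr", is_dose=False)   # -> ("2.0", "h")
ValueUnit.parse_value_unit("10.25nM", is_dose=True) # ValueError: max 1 decimal place
```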
2589
- class DoseHandler:
2590
- """Handler for dose-related operations."""
2591
-
2592
- VALID_UNITS = {"nM", "μM", "µM", "mM", "M"}
2593
- UNIT_MAP = {
2594
- "nm": "nM",
2595
- "NM": "nM",
2596
- "um": "μM",
2597
- "UM": "μM",
2598
- "μm": "μM",
2599
- "μM": "μM",
2600
- "µm": "μM",
2601
- "µM": "μM",
2602
- "mm": "mM",
2603
- "MM": "mM",
2604
- "m": "M",
2605
- "M": "M",
2606
- }
2607
-
2608
- @classmethod
2609
- def validate_unit(cls, unit: str) -> bool:
2610
- """Validate if the dose unit is acceptable."""
2611
- return unit in cls.VALID_UNITS
2612
-
2613
- @classmethod
2614
- def standardize_unit(cls, unit: str) -> str:
2615
- """Standardize dose unit to standard formats."""
2616
- return cls.UNIT_MAP.get(unit, unit)
2617
-
2618
- @classmethod
2619
- def validate_values(cls, values: pd.Series) -> list[str]:
2620
- """Validate pert_dose values with strict case checking."""
2621
- errors = []
2622
-
2623
- for idx, value in values.items():
2624
- if pd.isna(value):
2625
- continue
2626
-
2627
- if isinstance(value, (int, float)):
2628
- errors.append(
2629
- f"Row {idx} - Missing unit for dose: {value}. Must include a unit (nM, μM, mM, M)"
2630
- )
2631
- continue
2632
-
2633
- try:
2634
- ValueUnit.parse_value_unit(value, is_dose=True)
2635
- except ValueError as e:
2636
- errors.append(f"Row {idx} - {str(e)}")
2637
-
2638
- return errors
2639
-
2640
-
2641
- class TimeHandler:
2642
- """Handler for time-related operations."""
2643
-
2644
- VALID_UNITS = {"h", "m", "s", "d", "y"}
2645
-
2646
- @classmethod
2647
- def validate_unit(cls, unit: str) -> bool:
2648
- """Validate if the time unit is acceptable."""
2649
- return unit == unit.lower() and unit in cls.VALID_UNITS
2650
-
2651
- @classmethod
2652
- def standardize_unit(cls, unit: str) -> str:
2653
- """Standardize time unit to standard formats."""
2654
- if unit.startswith("hr"):
2655
- return "h"
2656
- elif unit.startswith("min"):
2657
- return "m"
2658
- elif unit.startswith("sec"):
2659
- return "s"
2660
- return unit[0].lower()
2661
-
2662
- @classmethod
2663
- def validate_values(cls, values: pd.Series) -> list[str]:
2664
- """Validate pert_time values."""
2665
- errors = []
2666
-
2667
- for idx, value in values.items():
2668
- if pd.isna(value):
2669
- continue
2670
-
2671
- if isinstance(value, (int, float)):
2672
- errors.append(
2673
- f"Row {idx} - Missing unit for time: {value}. Must include a unit (h, m, s, d, y)"
2674
- )
2675
- continue
2676
-
2677
- try:
2678
- ValueUnit.parse_value_unit(value, is_dose=False)
2679
- except ValueError as e:
2680
- errors.append(f"Row {idx} - {str(e)}")
2681
-
2682
- return errors
2683
-
2684
-
2685
- class PertAnnDataCatManager(CellxGeneAnnDataCatManager):
2686
- """Categorical manager for `AnnData` to manage perturbations."""
2687
-
2688
- PERT_COLUMNS = {"compound", "genetic", "biologic", "physical"}
2689
-
2690
- def __init__(
2691
- self,
2692
- adata: ad.AnnData,
2693
- organism: Literal["human", "mouse"] = "human",
2694
- pert_dose: bool = True,
2695
- pert_time: bool = True,
2696
- *,
2697
- cxg_schema_version: Literal["5.0.0", "5.1.0", "5.2.0"] = "5.2.0",
2698
- verbosity: str = "hint",
2699
- ):
2700
- """Initialize the curator with configuration and validation settings."""
2701
- self._pert_time = pert_time
2702
- self._pert_dose = pert_dose
2703
-
2704
- self._validate_initial_data(adata)
2705
- categoricals, categoricals_defaults = self._configure_categoricals(adata)
2706
-
2707
- super().__init__(
2708
- adata=adata,
2709
- categoricals=categoricals,
2710
- defaults=categoricals_defaults,
2711
- organism=organism,
2712
- extra_sources=self._configure_sources(adata),
2713
- schema_version=cxg_schema_version,
2714
- verbosity=verbosity,
2715
- )
2716
-
2717
- def _configure_categoricals(self, adata: ad.AnnData):
2718
- """Set up default configuration values."""
2719
- import bionty as bt
2720
- import wetlab as wl
2721
-
2722
- categoricals = CellxGeneAnnDataCatManager._get_cxg_categoricals() | {
2723
- k: v
2724
- for k, v in {
2725
- "cell_line": bt.CellLine.name,
2726
- "pert_target": wl.PerturbationTarget.name,
2727
- "pert_genetic": wl.GeneticPerturbation.name,
2728
- "pert_compound": wl.Compound.name,
2729
- "pert_biologic": wl.Biologic.name,
2730
- "pert_physical": wl.EnvironmentalPerturbation.name,
2731
- }.items()
2732
- if k in adata.obs.columns
2733
- }
2734
- # if "donor_id" in categoricals:
2735
- # categoricals["donor_id"] = Donor.name
2736
-
2737
- categoricals_defaults = CellxGeneAnnDataCatManager.cxg_categoricals_defaults | {
2738
- "cell_line": "unknown",
2739
- "pert_target": "unknown",
2740
- }
2741
-
2742
- return categoricals, categoricals_defaults
2743
-
2744
- def _configure_sources(self, adata: ad.AnnData):
2745
- """Set up data sources."""
2746
- import bionty as bt
2747
- import wetlab as wl
2748
-
2749
- sources = {}
2750
- # # do not yet specify cell_line source
2751
- # if "cell_line" in adata.obs.columns:
2752
- # sources["cell_line"] = bt.Source.filter(
2753
- # entity="bionty.CellLine", name="depmap"
2754
- # ).first()
2755
- if "pert_compound" in adata.obs.columns:
2756
- with logger.mute():
2757
- chebi_source = bt.Source.filter(
2758
- entity="wetlab.Compound", name="chebi"
2759
- ).first()
2760
- if not chebi_source:
2761
- wl.Compound.add_source(
2762
- bt.Source.filter(entity="Drug", name="chebi").first()
2763
- )
2764
-
2765
- sources["pert_compound"] = bt.Source.filter(
2766
- entity="wetlab.Compound", name="chebi"
2767
- ).first()
2768
- return sources
2769
-
2770
- def _validate_initial_data(self, adata: ad.AnnData):
2771
- """Validate the initial data structure."""
2772
- self._validate_required_columns(adata)
2773
- self._validate_perturbation_types(adata)
2774
-
2775
- def _validate_required_columns(self, adata: ad.AnnData):
2776
- """Validate required columns are present."""
2777
- if "pert_target" not in adata.obs.columns:
2778
- if (
2779
- "pert_name" not in adata.obs.columns
2780
- or "pert_type" not in adata.obs.columns
2781
- ):
2782
- raise ValidationError(
2783
- "either 'pert_target' or both 'pert_name' and 'pert_type' must be present"
2784
- )
2785
- else:
2786
- if "pert_name" not in adata.obs.columns:
2787
- logger.warning(
2788
- "no 'pert' column found in adata.obs, will only curate 'pert_target'"
2789
- )
2790
- elif "pert_type" not in adata.obs.columns:
2791
- raise ValidationError("both 'pert' and 'pert_type' must be present")
2792
-
2793
- def _validate_perturbation_types(self, adata: ad.AnnData):
2794
- """Validate perturbation types."""
2795
- if "pert_type" in adata.obs.columns:
2796
- data_pert_types = set(adata.obs["pert_type"].unique())
2797
- invalid_pert_types = data_pert_types - self.PERT_COLUMNS
2798
- if invalid_pert_types:
2799
- raise ValidationError(
2800
- f"invalid pert_type found: {invalid_pert_types}!\n"
2801
- f" → allowed values: {self.PERT_COLUMNS}"
2802
- )
2803
- self._process_perturbation_types(adata, data_pert_types)
2804
-
2805
- def _process_perturbation_types(self, adata: ad.AnnData, pert_types: set):
2806
- """Process and map perturbation types."""
2807
- for pert_type in pert_types:
2808
- col_name = "pert_" + pert_type
2809
- adata.obs[col_name] = adata.obs["pert_name"].where(
2810
- adata.obs["pert_type"] == pert_type, None
2811
- )
2812
- if adata.obs[col_name].dtype.name == "category":
2813
- adata.obs[col_name].cat.remove_unused_categories()
2814
- logger.important(f"mapped 'pert_name' to '{col_name}'")
2815
-
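A self-contained illustration of the fan-out this method performs, using plain pandas (column values are made up):

```python
import pandas as pd

obs = pd.DataFrame(
    {"pert_name": ["aspirin", "sgRNA-TP53"], "pert_type": ["compound", "genetic"]}
)
for pert_type in set(obs["pert_type"]):
    obs["pert_" + pert_type] = obs["pert_name"].where(
        obs["pert_type"] == pert_type, None
    )
# obs now has "pert_compound" and "pert_genetic" columns, each populated only
# where the row's pert_type matches.
```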
2816
- def validate(self) -> bool: # type: ignore
2817
- """Validate the AnnData object."""
2818
- validated = super().validate()
2819
-
2820
- if self._pert_dose:
2821
- validated &= self._validate_dose_column()
2822
- if self._pert_time:
2823
- validated &= self._validate_time_column()
2824
-
2825
- self._is_validated = validated
2826
-
2827
- # sort columns
2828
- first_columns = [
2829
- "pert_target",
2830
- "pert_genetic",
2831
- "pert_compound",
2832
- "pert_biologic",
2833
- "pert_physical",
2834
- "pert_dose",
2835
- "pert_time",
2836
- "organism",
2837
- "cell_line",
2838
- "cell_type",
2839
- "disease",
2840
- "tissue_type",
2841
- "tissue",
2842
- "assay",
2843
- "suspension_type",
2844
- "donor_id",
2845
- "sex",
2846
- "self_reported_ethnicity",
2847
- "development_stage",
2848
- "pert_name",
2849
- "pert_type",
2850
- ]
2851
- sorted_columns = [
2852
- col for col in first_columns if col in self._adata.obs.columns
2853
- ] + [col for col in self._adata.obs.columns if col not in first_columns]
2854
- # must assign to self._obs_df to ensure .standardize works correctly
2855
- self._obs_df = self._adata.obs[sorted_columns]
2856
- self._adata.obs = self._obs_df
2857
- return validated
2858
-
2859
- def standardize(self, key: str) -> None:
2860
- """Standardize the AnnData object."""
2861
- super().standardize(key)
2862
- self._adata.obs = self._obs_df
2863
-
2864
- def _validate_dose_column(self) -> bool:
2865
- """Validate the dose column."""
2866
- if not Feature.filter(name="pert_dose").exists():
2867
- Feature(name="pert_dose", dtype="str").save() # type: ignore
2868
-
2869
- dose_errors = DoseHandler.validate_values(self._adata.obs["pert_dose"])
2870
- if dose_errors:
2871
- self._log_validation_errors("pert_dose", dose_errors)
2872
- return False
2873
- return True
2874
-
2875
- def _validate_time_column(self) -> bool:
2876
- """Validate the time column."""
2877
- if not Feature.filter(name="pert_time").exists():
2878
- Feature(name="pert_time", dtype="str").save() # type: ignore
2879
-
2880
- time_errors = TimeHandler.validate_values(self._adata.obs["pert_time"])
2881
- if time_errors:
2882
- self._log_validation_errors("pert_time", time_errors)
2883
- return False
2884
- return True
2885
-
2886
- def _log_validation_errors(self, column: str, errors: list):
2887
- """Log validation errors with formatting."""
2888
- errors_print = "\n ".join(errors)
2889
- logger.warning(
2890
- f"invalid {column} values found!\n {errors_print}\n"
2891
- f" → run {colors.cyan('standardize_dose_time()')}"
2892
- )
2893
-
2894
- def standardize_dose_time(self) -> pd.DataFrame:
2895
- """Standardize dose and time values."""
2896
- standardized_df = self._adata.obs.copy()
2897
-
2898
- if "pert_dose" in self._adata.obs.columns:
2899
- standardized_df = self._standardize_column(
2900
- standardized_df, "pert_dose", is_dose=True
2901
- )
2902
-
2903
- if "pert_time" in self._adata.obs.columns:
2904
- standardized_df = self._standardize_column(
2905
- standardized_df, "pert_time", is_dose=False
2906
- )
2907
-
2908
- self._adata.obs = standardized_df
2909
- return standardized_df
2910
-
2911
- def _standardize_column(
2912
- self, df: pd.DataFrame, column: str, is_dose: bool
2913
- ) -> pd.DataFrame:
2914
- """Standardize values in a specific column."""
2915
- for idx, value in self._adata.obs[column].items():
2916
- if pd.isna(value) or (
2917
- isinstance(value, str) and (not value.strip() or value.lower() == "nan")
2918
- ):
2919
- df.at[idx, column] = None
2920
- continue
2921
-
2922
- try:
2923
- num, unit = ValueUnit.parse_value_unit(value, is_dose=is_dose)
2924
- df.at[idx, column] = f"{num}{unit}"
2925
- except ValueError:
2926
- continue
2927
-
2928
- return df
2929
-
2930
-
2931
- def get_current_filter_kwargs(registry: type[Record], kwargs: dict) -> dict:
2932
- """Make sure the source and organism are saved in the same database as the registry."""
2933
- db = registry.filter().db
2934
- source = kwargs.get("source")
2935
- organism = kwargs.get("organism")
2936
- filter_kwargs = kwargs.copy()
2937
-
2938
- if isinstance(organism, Record) and organism._state.db != "default":
2939
- if db is None or db == "default":
2940
- organism_default = copy.copy(organism)
2941
- # save the organism record in the default database
2942
- organism_default.save()
2943
- filter_kwargs["organism"] = organism_default
2944
- if isinstance(source, Record) and source._state.db != "default":
2945
- if db is None or db == "default":
2946
- source_default = copy.copy(source)
2947
- # save the source record in the default database
2948
- source_default.save()
2949
- filter_kwargs["source"] = source_default
2950
-
2951
- return filter_kwargs
2952
-
2953
-
2954
- def get_organism_kwargs(
2955
- field: FieldAttr, organism: str | None = None
2956
- ) -> dict[str, str]:
2957
- """Check if a registry needs an organism and return the organism name."""
2958
- registry = field.field.model
2959
- if registry.__base__.__name__ == "BioRecord":
2960
- import bionty as bt
2961
- from bionty._organism import is_organism_required
2962
-
2963
- from ..models._from_values import get_organism_record_from_field
2964
-
2965
- if is_organism_required(registry):
2966
- if organism is not None or bt.settings.organism is not None:
2967
- return {"organism": organism or bt.settings.organism.name}
2968
- else:
2969
- organism_record = get_organism_record_from_field(
2970
- field, organism=organism
2971
- )
2972
- if organism_record is not None:
2973
- return {"organism": organism_record.name}
2974
- return {}
2975
-
2976
-
2977
- def validate_categories(
2978
- values: Iterable[str],
2979
- field: FieldAttr,
2980
- key: str,
2981
- organism: str | None = None,
2982
- source: Record | None = None,
2983
- hint_print: str | None = None,
2984
- curator: CatManager | None = None,
2985
- ) -> tuple[bool, list[str]]:
2986
- """Validate ontology terms using LaminDB registries.
2987
-
2988
- Args:
2989
- values: The values to validate.
2990
- field: The field attribute.
2991
- key: The key referencing the slot in the DataFrame.
2992
- organism: The organism name.
2993
- source: The source record.
2994
- hint_print: The hint to print that suggests fixing non-validated values.
- curator: The curator on which to store validation error messages.
2996
- """
2997
- model_field = f"{field.field.model.__name__}.{field.field.name}"
2998
-
2999
- def _log_mapping_info():
3000
- logger.indent = ""
3001
- logger.info(f'mapping "{key}" on {colors.italic(model_field)}')
3002
- logger.indent = " "
3003
-
3004
- registry = field.field.model
3005
-
3006
- kwargs_current = get_current_filter_kwargs(
3007
- registry, {"organism": organism, "source": source}
3008
- )
3009
-
3010
- # inspect values from the default instance
3011
- inspect_result = registry.inspect(values, field=field, mute=True, **kwargs_current)
3012
- non_validated = inspect_result.non_validated
3013
- syn_mapper = inspect_result.synonyms_mapper
3014
-
3015
- # inspect the non-validated values from public (BioRecord only)
3016
- values_validated = []
3017
- if hasattr(registry, "public"):
3018
- public_records = registry.from_values(
3019
- non_validated,
3020
- field=field,
3021
- mute=True,
3022
- **kwargs_current,
3023
- )
3024
- values_validated += [getattr(r, field.field.name) for r in public_records]
3025
-
3026
- # logging messages
3027
- non_validated_hint_print = hint_print or f'.add_new_from("{key}")'
3028
- non_validated = [i for i in non_validated if i not in values_validated]
3029
- n_non_validated = len(non_validated)
3030
- if n_non_validated == 0:
3031
- logger.indent = ""
3032
- logger.success(f'"{key}" is validated against {colors.italic(model_field)}')
3033
- return True, []
3034
- else:
3035
- are = "is" if n_non_validated == 1 else "are"
3036
- s = "" if n_non_validated == 1 else "s"
3037
- print_values = _format_values(non_validated)
3038
- warning_message = f"{colors.red(f'{n_non_validated} term{s}')} {are} not validated: {colors.red(print_values)}\n"
3039
- if syn_mapper:
3040
- s = "" if len(syn_mapper) == 1 else "s"
3041
- syn_mapper_print = _format_values(
3042
- [f'"{k}" → "{v}"' for k, v in syn_mapper.items()], sep=""
3043
- )
3044
- hint_msg = f'.standardize("{key}")'
3045
- warning_message += f" {colors.yellow(f'{len(syn_mapper)} synonym{s}')} found: {colors.yellow(syn_mapper_print)}\n → curate synonyms via {colors.cyan(hint_msg)}"
3046
- if n_non_validated > len(syn_mapper):
3047
- if syn_mapper:
3048
- warning_message += "\n for remaining terms:\n"
3049
- warning_message += f" → fix typos, remove non-existent values, or save terms via {colors.cyan(non_validated_hint_print)}"
3050
-
3051
- if logger.indent == "":
3052
- _log_mapping_info()
3053
- logger.warning(warning_message)
3054
- if curator is not None:
3055
- curator._validate_category_error_messages = strip_ansi_codes(
3056
- warning_message
3057
- )
3058
- logger.indent = ""
3059
- return False, non_validated
3060
-
3061
-
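A hedged example of calling this validator directly; the values and field are illustrative and require `bionty`:

```python
import bionty as bt

is_valid, non_validated = validate_categories(
    values=["T cell", "B cell", "tcell"],
    field=bt.CellType.name,
    key="cell_type",
    organism="human",
)
# is_valid is False if, e.g., "tcell" neither validates nor maps to a synonym;
# non_validated then contains the offending terms.
```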
3062
- def standardize_categories(
3063
- values: Iterable[str],
3064
- field: FieldAttr,
3065
- organism: str | None = None,
3066
- source: Record | None = None,
3067
- ) -> dict:
3068
- """Get a synonym mapper."""
3069
- registry = field.field.model
3070
- if not hasattr(registry, "standardize"):
3071
- return {}
3072
- # standardize values using the default instance
3073
- syn_mapper = registry.standardize(
3074
- values,
3075
- field=field.field.name,
3076
- organism=organism,
3077
- source=source,
3078
- mute=True,
3079
- return_mapper=True,
3080
- )
3081
- return syn_mapper
3082
-
3083
-
3084
- def validate_categories_in_df(
3085
- df: pd.DataFrame,
3086
- fields: dict[str, FieldAttr],
3087
- sources: dict[str, Record] = None,
3088
- curator: CatManager | None = None,
3089
- **kwargs,
3090
- ) -> tuple[bool, dict]:
3091
- """Validate categories in DataFrame columns using LaminDB registries."""
3092
- if not fields:
3093
- return True, {}
3094
-
3095
- if sources is None:
3096
- sources = {}
3097
- validated = True
3098
- non_validated = {}
3099
- for key, field in fields.items():
3100
- is_val, non_val = validate_categories(
3101
- df[key],
3102
- field=field,
3103
- key=key,
3104
- source=sources.get(key),
3105
- curator=curator,
3106
- **kwargs,
3107
- )
3108
- validated &= is_val
3109
- if len(non_val) > 0:
3110
- non_validated[key] = non_val
3111
- return validated, non_validated
3112
-
3113
-
3114
- def save_artifact(
3115
- data: pd.DataFrame | ScverseDataStructures,
3116
- *,
3117
- fields: dict[str, FieldAttr] | dict[str, dict[str, FieldAttr]],
3118
- index_field: FieldAttr | dict[str, FieldAttr] | None = None,
3119
- description: str | None = None,
3120
- organism: str | None = None,
3121
- key: str | None = None,
3122
- artifact: Artifact | None = None,
3123
- revises: Artifact | None = None,
3124
- run: Run | None = None,
3125
- schema: Schema | None = None,
3126
- **kwargs,
3127
- ) -> Artifact:
3128
- """Save all metadata with an Artifact.
3129
-
3130
- Args:
3131
- data: The object to save.
3132
- fields: A dictionary mapping obs_column to registry_field.
3133
- index_field: The registry field to validate variables index against.
3134
- description: A description of the artifact.
3135
- organism: The organism name.
3136
- key: A path-like key to reference artifact in default storage, e.g., `"myfolder/myfile.fcs"`. Artifacts with the same key form a version family.
3137
- artifact: An already registered artifact. Passing this will not save a new artifact from data.
3138
- revises: Previous version of the artifact. Triggers a revision.
3139
- run: The run that creates the artifact.
3140
- schema: The Schema to associate with the Artifact.
3141
-
3142
- Returns:
3143
- The saved Artifact.
3144
- """
3145
- from ..models.artifact import add_labels
3146
-
3147
- if artifact is None:
3148
- if isinstance(data, pd.DataFrame):
3149
- artifact = Artifact.from_df(
3150
- data, description=description, key=key, revises=revises, run=run
3151
- )
3152
- elif isinstance(data, AnnData):
3153
- artifact = Artifact.from_anndata(
3154
- data, description=description, key=key, revises=revises, run=run
3155
- )
3156
- elif data_is_mudata(data):
3157
- artifact = Artifact.from_mudata(
3158
- data, description=description, key=key, revises=revises, run=run
3159
- )
3160
- elif data_is_spatialdata(data):
3161
- artifact = Artifact.from_spatialdata(
3162
- data, description=description, key=key, revises=revises, run=run
3163
- )
3164
- else:
3165
- raise InvalidArgument( # pragma: no cover
3166
- "data must be one of pd.Dataframe, AnnData, MuData, SpatialData."
3167
- )
3168
- artifact.save()
3169
-
3170
- def _add_labels(
3171
- data: pd.DataFrame | ScverseDataStructures,
3172
- artifact: Artifact,
3173
- fields: dict[str, FieldAttr],
3174
- feature_ref_is_name: bool | None = None,
3175
- ):
3176
- features = Feature.lookup().dict()
3177
- for key, field in fields.items():
3178
- feature = features.get(key)
3179
- registry = field.field.model
3180
- # we don't need source here because all records are already in the DB
3181
- filter_kwargs = get_current_filter_kwargs(registry, {"organism": organism})
3182
- df = data if isinstance(data, pd.DataFrame) else data.obs
3183
- # multi-value columns are separated by "|"
3184
- if not df[key].isna().all() and df[key].str.contains("|", regex=False).any():
3185
- values = df[key].str.split("|").explode().unique()
3186
- else:
3187
- values = df[key].unique()
3188
- labels = registry.from_values(values, field=field, **filter_kwargs)
3189
- if len(labels) == 0:
3190
- continue
3191
- label_ref_is_name = None
3192
- if hasattr(registry, "_name_field"):
3193
- label_ref_is_name = field.field.name == registry._name_field
3194
- add_labels(
3195
- artifact,
3196
- records=labels,
3197
- feature=feature,
3198
- feature_ref_is_name=feature_ref_is_name,
3199
- label_ref_is_name=label_ref_is_name,
3200
- from_curator=True,
3201
- )
3202
-
3203
- match artifact.otype:
3204
- case "DataFrame":
3205
- artifact.features._add_set_from_df(field=index_field, organism=organism) # type: ignore
3206
- _add_labels(
3207
- data, artifact, fields, feature_ref_is_name=_ref_is_name(index_field)
3208
- )
3209
- case "AnnData":
3210
- if schema is not None and "uns" in schema.slots:
3211
- uns_field = parse_cat_dtype(schema.slots["uns"].itype, is_itype=True)[
3212
- "field"
3213
- ]
3214
- else:
3215
- uns_field = None
3216
- artifact.features._add_set_from_anndata( # type: ignore
3217
- var_field=index_field, uns_field=uns_field, organism=organism
3218
- )
3219
- _add_labels(
3220
- data, artifact, fields, feature_ref_is_name=_ref_is_name(index_field)
3221
- )
3222
- case "MuData":
3223
- artifact.features._add_set_from_mudata( # type: ignore
3224
- var_fields=index_field, organism=organism
3225
- )
3226
- for modality, modality_fields in fields.items():
3227
- column_field_modality = index_field.get(modality)
3228
- if modality == "obs":
3229
- _add_labels(
3230
- data,
3231
- artifact,
3232
- modality_fields,
3233
- feature_ref_is_name=(
3234
- None
3235
- if column_field_modality is None
3236
- else _ref_is_name(column_field_modality)
3237
- ),
3238
- )
3239
- else:
3240
- _add_labels(
3241
- data[modality],
3242
- artifact,
3243
- modality_fields,
3244
- feature_ref_is_name=(
3245
- None
3246
- if column_field_modality is None
3247
- else _ref_is_name(column_field_modality)
3248
- ),
3249
- )
3250
- case "SpatialData":
3251
- artifact.features._add_set_from_spatialdata( # type: ignore
3252
- sample_metadata_key=kwargs.get("sample_metadata_key", "sample"),
3253
- var_fields=index_field,
3254
- organism=organism,
3255
- )
3256
- sample_metadata_key = kwargs.get("sample_metadata_key", "sample")
3257
- for accessor, accessor_fields in fields.items():
3258
- column_field = index_field.get(accessor)
3259
- if accessor == sample_metadata_key:
3260
- _add_labels(
3261
- data.get_attrs(
3262
- key=sample_metadata_key, return_as="df", flatten=True
3263
- ),
3264
- artifact,
3265
- accessor_fields,
3266
- feature_ref_is_name=(
3267
- None if column_field is None else _ref_is_name(column_field)
3268
- ),
3269
- )
3270
- else:
3271
- _add_labels(
3272
- data.tables[accessor],
3273
- artifact,
3274
- accessor_fields,
3275
- feature_ref_is_name=(
3276
- None if column_field is None else _ref_is_name(column_field)
3277
- ),
3278
- )
3279
- case _:
3280
- raise NotImplementedError # pragma: no cover
3281
-
3282
- artifact.schema = schema
3283
- artifact.save()
3284
-
3285
- slug = ln_setup.settings.instance.slug
3286
- if ln_setup.settings.instance.is_remote: # pdagma: no cover
3287
- logger.important(f"go to https://lamin.ai/{slug}/artifact/{artifact.uid}")
3288
- return artifact
3289
-
3290
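The removed helper above dispatches on the data type and builds on lamindb's public `Artifact.from_*` constructors. A minimal sketch of that entry point, assuming an initialized lamindb instance (the key and values are illustrative):

    import lamindb as ln
    import pandas as pd

    # a tiny table to register; Artifact.from_df stores it and tracks provenance
    df = pd.DataFrame({"cell_type": ["T cell", "B cell"]})
    artifact = ln.Artifact.from_df(
        df, key="examples/cells.parquet", description="toy table"
    ).save()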
-
-
- def _flatten_unique(series: pd.Series[list[Any] | Any]) -> list[Any]:
-     """Flatten a pandas Series containing lists or single items into a list of unique elements."""
-     result = set()
-
-     for item in series:
-         if isinstance(item, list):
-             result.update(item)
-         else:
-             result.add(item)
-
-     return list(result)
-
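Because the helper treats list-valued and scalar cells uniformly, its behavior is easy to check in isolation (pure pandas, relying only on the definition above):

    import pandas as pd

    series = pd.Series([["a", "b"], "b", "c"])
    # element order is not guaranteed because a set is used internally
    assert sorted(_flatten_unique(series)) == ["a", "b", "c"]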
-
-
- def update_registry(
-     values: list[str],
-     field: FieldAttr,
-     key: str,
-     validated_only: bool = True,
-     df: pd.DataFrame | None = None,
-     organism: str | None = None,
-     dtype: str | None = None,
-     source: Record | None = None,
-     **create_kwargs,
- ) -> None:
-     """Save feature or label records in the default instance.
-
-     Args:
-         values: A list of values to be saved as labels.
-         field: The FieldAttr object representing the field for which labels are being saved.
-         key: The name of the feature to save.
-         validated_only: If True, only save validated labels.
-         df: A DataFrame to save labels from.
-         organism: The organism name.
-         dtype: The type of the feature.
-         source: The source record.
-         **create_kwargs: Additional keyword arguments to pass to the registry model to create new records.
-     """
-     from lamindb.models.save import save as ln_save
-
-     registry = field.field.model
-     filter_kwargs = get_current_filter_kwargs(
-         registry, {"organism": organism, "source": source}
-     )
-     values = [i for i in values if isinstance(i, str) and i]
-     if not values:
-         return
-
-     labels_saved: dict = {"from public": [], "new": []}
-
-     # inspect the default instance and save validated records from public
-     existing_and_public_records = registry.from_values(
-         list(values), field=field, **filter_kwargs, mute=True
-     )
-     existing_and_public_labels = [
-         getattr(r, field.field.name) for r in existing_and_public_records
-     ]
-     # public records that are not already in the database
-     public_records = [r for r in existing_and_public_records if r._state.adding]
-     # only save public records that come from the specified source
-     # we check the uid because r.source and source can be from different instances
-     if source:
-         public_records = [r for r in public_records if r.source.uid == source.uid]
-     if len(public_records) > 0:
-         logger.info(f"saving validated records of '{key}'")
-         ln_save(public_records)
-         labels_saved["from public"] = [
-             getattr(r, field.field.name) for r in public_records
-         ]
-     # non-validated records from the default instance
-     non_validated_labels = [i for i in values if i not in existing_and_public_labels]
-
-     # save non-validated/new records
-     labels_saved["new"] = non_validated_labels
-     if not validated_only:
-         non_validated_records: RecordList[Any] = []  # type: ignore
-         if df is not None and registry == Feature:
-             nonval_columns = Feature.inspect(df.columns, mute=True).non_validated
-             non_validated_records = Feature.from_df(df.loc[:, nonval_columns])
-         else:
-             if (
-                 organism
-                 and hasattr(registry, "organism")
-                 and registry._meta.get_field("organism").is_relation
-             ):
-                 # make sure organism record is saved to the current instance
-                 create_kwargs["organism"] = _save_organism(name=organism)
-
-             for value in labels_saved["new"]:
-                 init_kwargs = {field.field.name: value}
-                 if registry == Feature:
-                     init_kwargs["dtype"] = "cat" if dtype is None else dtype
-                 non_validated_records.append(registry(**init_kwargs, **create_kwargs))
-         ln_save(non_validated_records)
-
-     # save a ULabel type for ulabels, e.g., type "Project" for labels under key "project"
-     if registry == ULabel and field.field.name == "name":
-         save_ulabels_type(values, field=field, key=key)
-
-     log_saved_labels(
-         labels_saved,
-         key=key,
-         model_field=f"{registry.__name__}.{field.field.name}",
-         validated_only=validated_only,
-     )
-
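This is how the legacy curators persisted validated and new categorical values. A hedged sketch of a call, assuming `bionty` is installed in an initialized instance (the values are illustrative):

    import bionty as bt

    # save both validated and brand-new cell-type labels for the "cell_type" feature
    update_registry(
        values=["T cell", "my in-house cell state"],
        field=bt.CellType.name,
        key="cell_type",
        validated_only=False,  # also create records for non-validated values
    )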
-
-
- def log_saved_labels(
-     labels_saved: dict,
-     key: str,
-     model_field: str,
-     validated_only: bool = True,
- ) -> None:
-     """Log the saved labels."""
-     from ..models._from_values import _format_values
-
-     model_field = colors.italic(model_field)
-     for k, labels in labels_saved.items():
-         if not labels:
-             continue
-         if k == "new" and validated_only:
-             continue
-         k = "" if k == "new" else f"{colors.green(k)} "
-         # the term "transferred" stresses that this is always in the context of transferring
-         # labels from a public ontology or a different instance to the present instance
-         s = "s" if len(labels) > 1 else ""
-         logger.success(
-             f'added {len(labels)} record{s} {k}with {model_field} for "{key}": {_format_values(labels)}'
-         )
-
-
-
- def save_ulabels_type(values: list[str], field: FieldAttr, key: str) -> None:
-     """Save the ULabel type of the given labels."""
-     registry = field.field.model
-     assert registry == ULabel  # noqa: S101
-     all_records = registry.filter(**{field.field.name: list(values)}).all()
-     # so `tissue_type` becomes `TissueType`
-     type_name = "".join([i.capitalize() for i in key.lower().split("_")])
-     ulabel_type = registry.filter(name=type_name, is_type=True).one_or_none()
-     if ulabel_type is None:
-         ulabel_type = registry(name=type_name, is_type=True).save()
-         logger.important(f"Created a ULabel type: {ulabel_type}")
-     all_records.update(type=ulabel_type)
-
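The type name is derived from the feature key by CamelCasing its snake_case parts; the transform in isolation:

    key = "tissue_type"
    type_name = "".join(part.capitalize() for part in key.lower().split("_"))
    assert type_name == "TissueType"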
-
-
- def _save_organism(name: str):
-     """Save an organism record."""
-     import bionty as bt
-
-     organism = bt.Organism.filter(name=name).one_or_none()
-     if organism is None:
-         organism = bt.Organism.from_source(name=name)
-         if organism is None:
-             raise ValidationError(
-                 f'Organism "{name}" not found from public reference\n'
-                 f'    → please save it from a different source: bt.Organism.from_source(name="{name}", source=source).save()\n'
-                 f'    → or manually save it without source: bt.Organism(name="{name}").save()'
-             )
-         organism.save()
-     return organism
-
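A sketch of how this helper resolves an organism, assuming an initialized instance with `bionty`: it checks the instance first and falls back to the public ontology.

    # looks up "human" in the instance, then in bionty's public Organism source
    organism = _save_organism(name="human")
    print(organism.name)  # "human", record now persisted in the current instance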
-
-
- def _ref_is_name(field: FieldAttr | None) -> bool | None:
-     """Check if the reference field is a name field."""
-     from ..models.can_curate import get_name_field
-
-     if field is not None:
-         name_field = get_name_field(field.field.model)
-         return field.field.name == name_field
-     return None
-
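Illustrative behavior, assuming a registry such as `bionty.CellType` whose configured name field is `name`:

    import bionty as bt

    assert _ref_is_name(bt.CellType.name) is True  # `name` is the name field
    assert _ref_is_name(bt.CellType.ontology_id) is False
    assert _ref_is_name(None) is None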
-
-
- # backward compat constructors ------------------
-
-
- @classmethod  # type: ignore
- def from_df(
-     cls,
-     df: pd.DataFrame,
-     categoricals: dict[str, FieldAttr] | None = None,
-     columns: FieldAttr = Feature.name,
-     verbosity: str = "hint",
-     organism: str | None = None,
- ) -> DataFrameCatManager:
-     return DataFrameCatManager(
-         df=df,
-         categoricals=categoricals,
-         columns=columns,
-         verbosity=verbosity,
-         organism=organism,
-     )
-
-
- @classmethod  # type: ignore
- def from_anndata(
-     cls,
-     data: ad.AnnData | UPathStr,
-     var_index: FieldAttr,
-     categoricals: dict[str, FieldAttr] | None = None,
-     obs_columns: FieldAttr = Feature.name,
-     verbosity: str = "hint",
-     organism: str | None = None,
-     sources: dict[str, Record] | None = None,
- ) -> AnnDataCatManager:
-     return AnnDataCatManager(
-         data=data,
-         var_index=var_index,
-         categoricals=categoricals,
-         obs_columns=obs_columns,
-         verbosity=verbosity,
-         organism=organism,
-         sources=sources,
-     )
-
-
- @classmethod  # type: ignore
- def from_mudata(
-     cls,
-     mdata: MuData | UPathStr,
-     var_index: dict[str, dict[str, FieldAttr]],
-     categoricals: dict[str, FieldAttr] | None = None,
-     verbosity: str = "hint",
-     organism: str | None = None,
- ) -> MuDataCatManager:
-     return MuDataCatManager(
-         mdata=mdata,
-         var_index=var_index,
-         categoricals=categoricals,
-         verbosity=verbosity,
-         organism=organism,
-     )
-
-
- @classmethod  # type: ignore
- def from_tiledbsoma(
-     cls,
-     experiment_uri: UPathStr,
-     var_index: dict[str, tuple[str, FieldAttr]],
-     categoricals: dict[str, FieldAttr] | None = None,
-     obs_columns: FieldAttr = Feature.name,
-     organism: str | None = None,
-     sources: dict[str, Record] | None = None,
- ) -> TiledbsomaCatManager:
-     return TiledbsomaCatManager(
-         experiment_uri=experiment_uri,
-         var_index=var_index,
-         categoricals=categoricals,
-         obs_columns=obs_columns,
-         organism=organism,
-         sources=sources,
-     )
-
-
- @classmethod  # type: ignore
- def from_spatialdata(
-     cls,
-     sdata: SpatialData | UPathStr,
-     var_index: dict[str, FieldAttr],
-     categoricals: dict[str, dict[str, FieldAttr]] | None = None,
-     organism: str | None = None,
-     sources: dict[str, dict[str, Record]] | None = None,
-     verbosity: str = "hint",
-     *,
-     sample_metadata_key: str = "sample",
- ):
-     try:
-         import spatialdata  # noqa: F401  (availability check only)
-     except ImportError as e:
-         raise ImportError("Please install spatialdata: pip install spatialdata") from e
-
-     return SpatialDataCatManager(
-         sdata=sdata,
-         var_index=var_index,
-         categoricals=categoricals,
-         verbosity=verbosity,
-         organism=organism,
-         sources=sources,
-         sample_metadata_key=sample_metadata_key,
-     )
-
-
- CatManager.from_df = from_df  # type: ignore
- CatManager.from_anndata = from_anndata  # type: ignore
- CatManager.from_mudata = from_mudata  # type: ignore
- CatManager.from_spatialdata = from_spatialdata  # type: ignore
- CatManager.from_tiledbsoma = from_tiledbsoma  # type: ignore
+ from .core import AnnDataCurator, DataFrameCurator, MuDataCurator, SpatialDataCurator
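With 1.5.0, the curator classes re-exported from the new `core` module replace the removed `CatManager.from_*` constructors. A hedged migration sketch, assuming an initialized instance and a `DataFrameCurator(dataset, schema)` signature (feature name, schema, and data are illustrative; see `lamindb.curators.core` for the actual API):

    import lamindb as ln
    import pandas as pd

    # a schema with one required categorical column
    cell_type = ln.Feature(name="cell_type", dtype=str).save()
    schema = ln.Schema(name="toy schema", features=[cell_type]).save()
    df = pd.DataFrame({"cell_type": ["T cell"]})
    curator = ln.curators.DataFrameCurator(df, schema)
    curator.validate()  # raises on mismatch between df and schema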