lamindb 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. lamindb/__init__.py +52 -36
  2. lamindb/_finish.py +17 -10
  3. lamindb/_tracked.py +1 -1
  4. lamindb/base/__init__.py +3 -1
  5. lamindb/base/fields.py +40 -22
  6. lamindb/base/ids.py +1 -94
  7. lamindb/base/types.py +2 -0
  8. lamindb/base/uids.py +117 -0
  9. lamindb/core/_context.py +203 -102
  10. lamindb/core/_settings.py +38 -25
  11. lamindb/core/datasets/__init__.py +11 -4
  12. lamindb/core/datasets/_core.py +5 -5
  13. lamindb/core/datasets/_small.py +0 -93
  14. lamindb/core/datasets/mini_immuno.py +172 -0
  15. lamindb/core/loaders.py +1 -1
  16. lamindb/core/storage/_backed_access.py +100 -6
  17. lamindb/core/storage/_polars_lazy_df.py +51 -0
  18. lamindb/core/storage/_pyarrow_dataset.py +15 -30
  19. lamindb/core/storage/_tiledbsoma.py +29 -13
  20. lamindb/core/storage/objects.py +6 -0
  21. lamindb/core/subsettings/__init__.py +2 -0
  22. lamindb/core/subsettings/_annotation_settings.py +11 -0
  23. lamindb/curators/__init__.py +7 -3349
  24. lamindb/curators/_legacy.py +2056 -0
  25. lamindb/curators/core.py +1534 -0
  26. lamindb/errors.py +11 -0
  27. lamindb/examples/__init__.py +27 -0
  28. lamindb/examples/schemas/__init__.py +12 -0
  29. lamindb/examples/schemas/_anndata.py +25 -0
  30. lamindb/examples/schemas/_simple.py +19 -0
  31. lamindb/integrations/_vitessce.py +8 -5
  32. lamindb/migrations/0091_alter_featurevalue_options_alter_space_options_and_more.py +24 -0
  33. lamindb/migrations/0092_alter_artifactfeaturevalue_artifact_and_more.py +75 -0
  34. lamindb/migrations/0093_alter_schemacomponent_unique_together.py +16 -0
  35. lamindb/models/__init__.py +4 -1
  36. lamindb/models/_describe.py +21 -4
  37. lamindb/models/_feature_manager.py +382 -287
  38. lamindb/models/_label_manager.py +8 -2
  39. lamindb/models/artifact.py +177 -106
  40. lamindb/models/artifact_set.py +122 -0
  41. lamindb/models/collection.py +73 -52
  42. lamindb/models/core.py +1 -1
  43. lamindb/models/feature.py +51 -17
  44. lamindb/models/has_parents.py +69 -14
  45. lamindb/models/project.py +1 -1
  46. lamindb/models/query_manager.py +221 -22
  47. lamindb/models/query_set.py +247 -172
  48. lamindb/models/record.py +65 -247
  49. lamindb/models/run.py +4 -4
  50. lamindb/models/save.py +8 -2
  51. lamindb/models/schema.py +456 -184
  52. lamindb/models/transform.py +2 -2
  53. lamindb/models/ulabel.py +8 -5
  54. {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/METADATA +6 -6
  55. {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/RECORD +57 -43
  56. {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/LICENSE +0 -0
  57. {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/WHEEL +0 -0
lamindb/curators/_legacy.py (new file)
@@ -0,0 +1,2056 @@
+ from __future__ import annotations
+
+ import re
+ from itertools import chain
+ from typing import TYPE_CHECKING, Any, Iterable, Literal
+
+ import pandas as pd
+ import pyarrow as pa
+ from lamin_utils import colors, logger
+ from lamindb_setup.core import deprecated
+ from lamindb_setup.core.upath import UPath
+
+ from lamindb.core._compat import is_package_installed
+ from lamindb.models.artifact import (
+     data_is_anndata,
+     data_is_mudata,
+     data_is_spatialdata,
+ )
+
+ from ..errors import InvalidArgument
+
+ if TYPE_CHECKING:
+     from lamindb_setup.core.types import UPathStr
+     from mudata import MuData
+     from spatialdata import SpatialData
+
+     from lamindb.models import Record
+ from lamindb.base.types import FieldAttr  # noqa
+ from lamindb.models import (
+     Artifact,
+     Feature,
+     Record,
+     Run,
+     Schema,
+ )
+ from lamindb.models.artifact import (
+     add_labels,
+ )
+ from lamindb.models._from_values import _format_values
+ from .core import CatLookup, CatVector
+ from ..errors import ValidationError
+ import anndata as ad
+
+
+ def _ref_is_name(field: FieldAttr | None) -> bool | None:
+     """Check if the reference field is a name field."""
+     from ..models.can_curate import get_name_field
+
+     if field is not None:
+         name_field = get_name_field(field.field.model)
+         return field.field.name == name_field
+     return None
+
+
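
As a quick illustration of the helper above (a sketch, assuming a registry such as `Feature` whose name field is `name`):

    from lamindb.models import Feature

    _ref_is_name(Feature.name)  # True: "name" is the registry's name field
    _ref_is_name(Feature.uid)   # False: an identifier field, not the name field
    _ref_is_name(None)          # None: nothing to check
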
+ class CatManager:
+     """Manage categoricals by updating registries.
+
+     This class is accessible from within a `DataFrameCurator` via the `.cat` attribute.
+
+     If you find non-validated values, you have several options:
+
+     - new values found in the data can be registered via `DataFrameCurator.cat.add_new_from()` (:meth:`~lamindb.curators.DataFrameCatManager.add_new_from`)
+     - non-validated values can be accessed via `DataFrameCurator.cat.non_validated` (:meth:`~lamindb.curators.DataFrameCatManager.non_validated`) and addressed manually
+     """
+
+     def __init__(self, *, dataset, categoricals, sources, columns_field=None):
+         # the below is shared with Curator
+         self._artifact: Artifact = None  # pass the dataset as an artifact
+         self._dataset: Any = dataset  # pass the dataset as a UPathStr or data object
+         if isinstance(self._dataset, Artifact):
+             self._artifact = self._dataset
+             if self._artifact.otype in {"DataFrame", "AnnData"}:
+                 self._dataset = self._dataset.load(
+                     is_run_input=False  # we already track this in the Curator constructor
+                 )
+         self._is_validated: bool = False
+         # shared until here
+         self._categoricals = categoricals or {}
+         self._non_validated = None
+         self._sources = sources or {}
+         self._columns_field = columns_field
+         self._validate_category_error_messages: str = ""
+         self._cat_vectors: dict[str, CatVector] = {}
+
+     @property
+     def non_validated(self) -> dict[str, list[str]]:
+         """Return the non-validated features and labels."""
+         if self._non_validated is None:
+             raise ValidationError("Please run validate() first!")
+         return {
+             key: cat_vector._non_validated
+             for key, cat_vector in self._cat_vectors.items()
+             if cat_vector._non_validated and key != "columns"
+         }
+
+     @property
+     def categoricals(self) -> dict:
+         """Return the columns fields to validate against."""
+         return self._categoricals
+
+     def validate(self) -> bool:
+         """Validate dataset.
+
+         This method also registers the validated records in the current instance.
+
+         Returns:
+             Whether the dataset is validated.
+         """
+         pass  # pragma: no cover
+
+     def standardize(self, key: str) -> None:
+         """Replace synonyms with standardized values.
+
+         Inplace modification of the dataset.
+
+         Args:
+             key: The name of the column to standardize.
+
+         Returns:
+             None
+         """
+         pass  # pragma: no cover
+
+     def save_artifact(
+         self,
+         *,
+         key: str | None = None,
+         description: str | None = None,
+         revises: Artifact | None = None,
+         run: Run | None = None,
+     ) -> Artifact:
+         """{}"""  # noqa: D415
+         # Make sure all labels are saved in the current instance
+         if not self._is_validated:
+             self.validate()  # returns True or False
+             if not self._is_validated:  # need to raise error manually
+                 raise ValidationError("Dataset does not validate. Please curate.")
+
+         if self._artifact is None:
+             if isinstance(self._dataset, pd.DataFrame):
+                 artifact = Artifact.from_df(
+                     self._dataset,
+                     key=key,
+                     description=description,
+                     revises=revises,
+                     run=run,
+                 )
+             elif isinstance(self._dataset, ad.AnnData):
+                 artifact = Artifact.from_anndata(
+                     self._dataset,
+                     key=key,
+                     description=description,
+                     revises=revises,
+                     run=run,
+                 )
+             elif data_is_mudata(self._dataset):
+                 artifact = Artifact.from_mudata(
+                     self._dataset,
+                     key=key,
+                     description=description,
+                     revises=revises,
+                     run=run,
+                 )
+             elif data_is_spatialdata(self._dataset):
+                 artifact = Artifact.from_spatialdata(
+                     self._dataset,
+                     key=key,
+                     description=description,
+                     revises=revises,
+                     run=run,
+                 )
+             else:
+                 raise InvalidArgument(  # pragma: no cover
+                     "data must be one of pd.DataFrame, AnnData, MuData, SpatialData."
+                 )
+             self._artifact = artifact.save()
+
+         legacy_annotate_artifact(  # type: ignore
+             self._artifact,
+             index_field=self._columns_field,
+             cat_vectors=self._cat_vectors,
+         )
+         return self._artifact
+
+
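
The validate → fix → save flow that `CatManager` defines looks as follows in a minimal, hypothetical sketch; the DataFrame, column name, and `ULabel` field are illustrative assumptions, and a `Feature` record matching the column name is assumed to exist already:

    import lamindb as ln
    import pandas as pd
    from lamindb.curators._legacy import DataFrameCatManager

    df = pd.DataFrame({"cell_type": ["T cell", "tcell"]})  # hypothetical data
    curator = DataFrameCatManager(df, categoricals={"cell_type": ln.ULabel.name})
    if not curator.validate():
        curator.standardize("cell_type")   # map synonyms such as "tcell"
        curator.add_new_from("cell_type")  # register whatever is still new
    artifact = curator.save_artifact(key="examples/curated.parquet")
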
+ class DataFrameCatManager(CatManager):
+     """Categorical manager for `DataFrame`."""
+
+     def __init__(
+         self,
+         df: pd.DataFrame | Artifact,
+         columns_field: FieldAttr = Feature.name,
+         columns_names: Iterable[str] | None = None,
+         categoricals: dict[str, FieldAttr] | None = None,
+         sources: dict[str, Record] | None = None,
+         index: Feature | None = None,
+     ) -> None:
+         self._non_validated = None
+         self._index = index
+         super().__init__(
+             dataset=df,
+             columns_field=columns_field,
+             categoricals=categoricals,
+             sources=sources,
+         )
+         if columns_names is None:
+             columns_names = []
+         if columns_field == Feature.name:
+             values = list(self._categoricals.keys())  # backward compat
+             self._cat_vectors["columns"] = CatVector(
+                 values_getter=values,
+                 field=self._columns_field,
+                 key="columns" if isinstance(self._dataset, pd.DataFrame) else "keys",
+                 source=self._sources.get("columns"),
+             )
+             if isinstance(self._categoricals, dict):  # backward compat
+                 self._cat_vectors["columns"].validate()
+         else:
+             # NOTE: for var_index right now
+             self._cat_vectors["columns"] = CatVector(
+                 values_getter=lambda: self._dataset.columns,  # lambda ensures the inplace update
+                 values_setter=lambda new_values: setattr(
+                     self._dataset, "columns", pd.Index(new_values)
+                 ),
+                 field=self._columns_field,
+                 key="columns",
+                 source=self._sources.get("columns"),
+             )
+         for key, field in self._categoricals.items():
+             self._cat_vectors[key] = CatVector(
+                 values_getter=lambda k=key: self._dataset[
+                     k
+                 ],  # Capture key as default argument
+                 values_setter=lambda new_values, k=key: self._dataset.__setitem__(
+                     k, new_values
+                 ),
+                 field=field,
+                 key=key,
+                 source=self._sources.get(key),
+                 feature=Feature.get(name=key),
+             )
+
+     def lookup(self, public: bool = False) -> CatLookup:
+         """Lookup categories.
+
+         Args:
+             public: If `True`, the lookup is performed on the public reference.
+         """
+         return CatLookup(
+             categoricals=self._categoricals,
+             slots={"columns": self._columns_field},
+             public=public,
+             sources=self._sources,
+         )
+
+     def validate(self) -> bool:
+         """Validate variables and categorical observations."""
+         self._validate_category_error_messages = ""  # reset the error messages
+
+         validated = True
+         for _, cat_vector in self._cat_vectors.items():
+             cat_vector.validate()
+             validated &= cat_vector.is_validated
+         self._is_validated = validated
+         self._non_validated = {}  # so it's no longer None
+
+         if self._index is not None:
+             # cat_vector.validate() populates validated labels
+             # the index should become part of the feature set corresponding to the dataframe
+             self._cat_vectors["columns"].labels.insert(0, self._index)  # type: ignore
+
+         return self._is_validated
+
+     def standardize(self, key: str) -> None:
+         """Replace synonyms with standardized values.
+
+         Modifies the input dataset inplace.
+
+         Args:
+             key: The key referencing the column in the DataFrame to standardize.
+         """
+         if self._artifact is not None:
+             raise RuntimeError("can't mutate the dataset when an artifact is passed!")
+
+         if key == "all":
+             logger.warning(
+                 "'all' is deprecated, please pass a single key from `.non_validated.keys()` instead!"
+             )
+             for k in self.non_validated.keys():
+                 self._cat_vectors[k].standardize()
+         else:
+             self._cat_vectors[key].standardize()
+
+     def add_new_from(self, key: str, **kwargs):
+         """Add validated & new categories.
+
+         Args:
+             key: The key referencing the slot in the DataFrame from which to draw terms.
+             **kwargs: Additional keyword arguments to pass to create new records.
+         """
+         if len(kwargs) > 0 and key == "all":
+             raise ValueError("Cannot pass additional arguments to 'all' key!")
+         if key == "all":
+             logger.warning(
+                 "'all' is deprecated, please pass a single key from `.non_validated.keys()` instead!"
+             )
+             for k in self.non_validated.keys():
+                 self._cat_vectors[k].add_new(**kwargs)
+         else:
+             self._cat_vectors[key].add_new(**kwargs)
+
+     @deprecated(
+         new_name="Run.filter(transform=context.run.transform, output_artifacts=None)"
+     )
+     def clean_up_failed_runs(self):
+         """Clean up previous failed runs that don't save any outputs."""
+         from lamindb.core._context import context
+
+         if context.run is not None:
+             Run.filter(transform=context.run.transform, output_artifacts=None).exclude(
+                 uid=context.run.uid
+             ).delete()
+
+
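
The `k=key` default arguments in the loop above are the standard workaround for Python's late-binding closures; without them, every lambda would read the loop's final key. In isolation:

    fns = [lambda: key for key in ("a", "b")]      # late binding
    [f() for f in fns]                             # ['b', 'b']

    fns = [lambda k=key: k for key in ("a", "b")]  # default arg captures each value
    [f() for f in fns]                             # ['a', 'b']
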
+ class AnnDataCatManager(CatManager):
+     """Categorical manager for `AnnData`."""
+
+     def __init__(
+         self,
+         data: ad.AnnData | Artifact,
+         var_index: FieldAttr | None = None,
+         categoricals: dict[str, FieldAttr] | None = None,
+         obs_columns: FieldAttr = Feature.name,
+         sources: dict[str, Record] | None = None,
+     ) -> None:
+         if isinstance(var_index, str):
+             raise TypeError(
+                 "var_index parameter has to be a field, e.g. Gene.ensembl_gene_id"
+             )
+
+         if not data_is_anndata(data):
+             raise TypeError("data has to be an AnnData object")
+
+         if "symbol" in str(var_index):
+             logger.warning(
+                 "indexing datasets with gene symbols can be problematic: https://docs.lamin.ai/faq/symbol-mapping"
+             )
+
+         self._obs_fields = categoricals or {}
+         self._var_field = var_index
+         self._sources = sources or {}
+         super().__init__(
+             dataset=data,
+             categoricals=categoricals,
+             sources=self._sources,
+             columns_field=var_index,
+         )
+         self._adata = self._dataset
+         self._obs_df_curator = DataFrameCatManager(
+             df=self._adata.obs,
+             categoricals=self.categoricals,
+             columns_field=obs_columns,
+             sources=self._sources,
+         )
+         self._cat_vectors = self._obs_df_curator._cat_vectors.copy()
+         if var_index is not None:
+             self._cat_vectors["var_index"] = CatVector(
+                 values_getter=lambda: self._adata.var.index,
+                 values_setter=lambda new_values: setattr(
+                     self._adata.var, "index", pd.Index(new_values)
+                 ),
+                 field=self._var_field,
+                 key="var_index",
+                 source=self._sources.get("var_index"),
+             )
+
+     @property
+     def var_index(self) -> FieldAttr:
+         """Return the registry field to validate variables index against."""
+         return self._var_field
+
+     @property
+     def categoricals(self) -> dict:
+         """Return the obs fields to validate against."""
+         return self._obs_fields
+
+     def lookup(self, public: bool = False) -> CatLookup:
+         """Lookup categories.
+
+         Args:
+             public: If `True`, the lookup is performed on the public reference.
+         """
+         return CatLookup(
+             categoricals=self._obs_fields,
+             slots={"columns": self._columns_field, "var_index": self._var_field},
+             public=public,
+             sources=self._sources,
+         )
+
+     def add_new_from(self, key: str, **kwargs):
+         """Add validated & new categories.
+
+         Args:
+             key: The key referencing the slot in the DataFrame from which to draw terms.
+             **kwargs: Additional keyword arguments to pass to create new records.
+         """
+         if key == "all":
+             logger.warning(
+                 "'all' is deprecated, please pass a single key from `.non_validated.keys()` instead!"
+             )
+             for k in self.non_validated.keys():
+                 self._cat_vectors[k].add_new(**kwargs)
+         else:
+             self._cat_vectors[key].add_new(**kwargs)
+
+     @deprecated(new_name="add_new_from('var_index')")
+     def add_new_from_var_index(self, **kwargs):
+         """Update variable records.
+
+         Args:
+             **kwargs: Additional keyword arguments to pass to create new records.
+         """
+         self.add_new_from(key="var_index", **kwargs)
+
+     def validate(self) -> bool:
+         """Validate categories.
+
+         This method also registers the validated records in the current instance.
+
+         Returns:
+             Whether the AnnData object is validated.
+         """
+         self._validate_category_error_messages = ""  # reset the error messages
+
+         validated = True
+         for _, cat_vector in self._cat_vectors.items():
+             cat_vector.validate()
+             validated &= cat_vector.is_validated
+
+         self._non_validated = {}  # so it's no longer None
+         self._is_validated = validated
+         return self._is_validated
+
+     def standardize(self, key: str):
+         """Replace synonyms with standardized values.
+
+         Args:
+             key: The key referencing the slot in `adata.obs` from which to draw terms. Same as the key in `categoricals`.
+
+                 - If "var_index", standardize the var.index.
+                 - If "all", standardize all obs columns and var.index.
+
+         Inplace modification of the dataset.
+         """
+         if self._artifact is not None:
+             raise RuntimeError("can't mutate the dataset when an artifact is passed!")
+         if key == "all":
+             logger.warning(
+                 "'all' is deprecated, please pass a single key from `.non_validated.keys()` instead!"
+             )
+             for k in self.non_validated.keys():
+                 self._cat_vectors[k].standardize()
+         else:
+             self._cat_vectors[key].standardize()
+
+
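
A hypothetical `AnnDataCatManager` session, assuming bionty is installed and `adata` is an in-memory AnnData whose `var` index holds Ensembl gene ids:

    import bionty as bt
    from lamindb.curators._legacy import AnnDataCatManager

    curator = AnnDataCatManager(
        adata,
        var_index=bt.Gene.ensembl_gene_id,
        categoricals={"cell_type": bt.CellType.name},
    )
    if not curator.validate():
        curator.standardize("var_index")   # fix synonyms in the gene index
        curator.add_new_from("cell_type")  # register remaining new labels
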
+ @deprecated(new_name="MuDataCurator")
+ class MuDataCatManager(CatManager):
+     """Categorical manager for `MuData`."""
+
+     def __init__(
+         self,
+         mdata: MuData | Artifact,
+         var_index: dict[str, FieldAttr] | None = None,
+         categoricals: dict[str, FieldAttr] | None = None,
+         sources: dict[str, Record] | None = None,
+     ) -> None:
+         super().__init__(
+             dataset=mdata,
+             categoricals={},
+             sources=sources,
+         )
+         self._columns_field = (
+             var_index or {}
+         )  # this is for consistency with BaseCatManager
+         self._var_fields = var_index or {}
+         self._verify_modality(self._var_fields.keys())
+         self._obs_fields = self._parse_categoricals(categoricals or {})
+         self._modalities = set(self._var_fields.keys()) | set(self._obs_fields.keys())
+         self._obs_df_curator = None
+         if "obs" in self._modalities:
+             self._obs_df_curator = DataFrameCatManager(
+                 df=self._dataset.obs,
+                 columns_field=Feature.name,
+                 categoricals=self._obs_fields.get("obs", {}),
+                 sources=self._sources.get("obs"),
+             )
+         self._mod_adata_curators = {
+             modality: AnnDataCatManager(
+                 data=self._dataset[modality],
+                 var_index=var_index.get(modality),
+                 categoricals=self._obs_fields.get(modality),
+                 sources=self._sources.get(modality),
+             )
+             for modality in self._modalities
+             if modality != "obs"
+         }
+         self._non_validated = None
+
+     @property
+     def var_index(self) -> FieldAttr:
+         """Return the registry field to validate variables index against."""
+         return self._var_fields
+
+     @property
+     def categoricals(self) -> dict:
+         """Return the obs fields to validate against."""
+         return self._obs_fields
+
+     @property
+     def non_validated(self) -> dict[str, dict[str, list[str]]]:  # type: ignore
+         """Return the non-validated features and labels."""
+         if self._non_validated is None:
+             raise ValidationError("Please run validate() first!")
+         non_validated = {}
+         if (
+             self._obs_df_curator is not None
+             and len(self._obs_df_curator.non_validated) > 0
+         ):
+             non_validated["obs"] = self._obs_df_curator.non_validated
+         for modality, adata_curator in self._mod_adata_curators.items():
+             if len(adata_curator.non_validated) > 0:
+                 non_validated[modality] = adata_curator.non_validated
+         self._non_validated = non_validated
+         return self._non_validated
+
+     def _verify_modality(self, modalities: Iterable[str]):
+         """Verify the modality exists."""
+         for modality in modalities:
+             if modality not in self._dataset.mod.keys():
+                 raise ValidationError(f"modality '{modality}' does not exist!")
+
+     def _parse_categoricals(self, categoricals: dict[str, FieldAttr]) -> dict:
+         """Parse the categorical fields."""
+         prefixes = {f"{k}:" for k in self._dataset.mod.keys()}
+         obs_fields: dict[str, dict[str, FieldAttr]] = {}
+         for k, v in categoricals.items():
+             if k not in self._dataset.obs.columns:
+                 raise ValidationError(f"column '{k}' does not exist in mdata.obs!")
+             if any(k.startswith(prefix) for prefix in prefixes):
+                 modality, col = k.split(":")[0], k.split(":")[1]
+                 if modality not in obs_fields.keys():
+                     obs_fields[modality] = {}
+                 obs_fields[modality][col] = v
+             else:
+                 if "obs" not in obs_fields.keys():
+                     obs_fields["obs"] = {}
+                 obs_fields["obs"][k] = v
+         return obs_fields
+
+     def lookup(self, public: bool = False) -> CatLookup:
+         """Lookup categories.
+
+         Args:
+             public: Perform lookup on public source ontologies.
+         """
+         obs_fields = {}
+         for mod, fields in self._obs_fields.items():
+             for k, v in fields.items():
+                 if k == "obs":
+                     obs_fields[k] = v
+                 else:
+                     obs_fields[f"{mod}:{k}"] = v
+         return CatLookup(
+             categoricals=obs_fields,
+             slots={
+                 **{f"{k}_var_index": v for k, v in self._var_fields.items()},
+             },
+             public=public,
+             sources=self._sources,
+         )
+
+     @deprecated(new_name="add_new_from('var_index')")
+     def add_new_from_var_index(self, modality: str, **kwargs):
+         """Update variable records.
+
+         Args:
+             modality: The modality name.
+             **kwargs: Additional keyword arguments to pass to create new records.
+         """
+         self._mod_adata_curators[modality].add_new_from(key="var_index", **kwargs)
+
+     def add_new_from(
+         self,
+         key: str,
+         modality: str | None = None,
+         **kwargs,
+     ):
+         """Add validated & new categories.
+
+         Args:
+             key: The key referencing the slot in the DataFrame.
+             modality: The modality name.
+             **kwargs: Additional keyword arguments to pass to create new records.
+         """
+         modality = modality or "obs"
+         if modality in self._mod_adata_curators:
+             adata_curator = self._mod_adata_curators[modality]
+             adata_curator.add_new_from(key=key, **kwargs)
+         if modality == "obs":
+             self._obs_df_curator.add_new_from(key=key, **kwargs)
+         if key == "var_index":
+             self._mod_adata_curators[modality].add_new_from(key=key, **kwargs)
+
+     def validate(self) -> bool:
+         """Validate categories."""
+         obs_validated = True
+         if "obs" in self._modalities:
+             logger.info('validating categoricals in "obs"...')
+             obs_validated &= self._obs_df_curator.validate()
+
+         mods_validated = True
+         for modality, adata_curator in self._mod_adata_curators.items():
+             logger.info(f'validating categoricals in modality "{modality}"...')
+             mods_validated &= adata_curator.validate()
+
+         self._non_validated = {}  # so it's no longer None
+         self._is_validated = obs_validated & mods_validated
+         return self._is_validated
+
+     def standardize(self, key: str, modality: str | None = None):
+         """Replace synonyms with standardized values.
+
+         Args:
+             key: The key referencing the slot in the `MuData`.
+             modality: The modality name.
+
+         Inplace modification of the dataset.
+         """
+         if self._artifact is not None:
+             raise RuntimeError("can't mutate the dataset when an artifact is passed!")
+         modality = modality or "obs"
+         if modality in self._mod_adata_curators:
+             adata_curator = self._mod_adata_curators[modality]
+             adata_curator.standardize(key=key)
+         if modality == "obs":
+             self._obs_df_curator.standardize(key=key)
+
+
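
The modality-prefix convention parsed by `_parse_categoricals` routes each key either to a per-modality AnnData manager or to the shared `obs` frame. A hedged sketch (`mdata`, its `rna` modality, and the registries are assumptions; the keys must exist as `mdata.obs` columns):

    import bionty as bt
    import lamindb as ln
    from lamindb.curators._legacy import MuDataCatManager

    curator = MuDataCatManager(
        mdata,  # a MuData object with an "rna" modality
        var_index={"rna": bt.Gene.ensembl_gene_id},
        categoricals={
            "rna:cell_type": bt.CellType.name,  # routed to the "rna" AnnData curator
            "donor": ln.ULabel.name,            # no prefix: routed to mdata.obs
        },
    )
    curator.validate()
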
+ def _maybe_curation_keys_not_present(nonval_keys: list[str], name: str):
+     if (n := len(nonval_keys)) > 0:
+         s = "s" if n > 1 else ""
+         are = "are" if n > 1 else "is"
+         raise ValidationError(
+             f"key{s} passed to {name} {are} not present: {colors.yellow(_format_values(nonval_keys))}"
+         )
+
+
+ @deprecated(new_name="SpatialDataCurator")
+ class SpatialDataCatManager(CatManager):
+     """Categorical manager for `SpatialData`."""
+
+     def __init__(
+         self,
+         sdata: Any,
+         var_index: dict[str, FieldAttr],
+         categoricals: dict[str, dict[str, FieldAttr]] | None = None,
+         sources: dict[str, dict[str, Record]] | None = None,
+         *,
+         sample_metadata_key: str | None = "sample",
+     ) -> None:
+         super().__init__(
+             dataset=sdata,
+             categoricals={},
+             sources=sources,
+         )
+         if isinstance(sdata, Artifact):
+             self._sdata = sdata.load()
+         else:
+             self._sdata = self._dataset
+         self._sample_metadata_key = sample_metadata_key
+         self._write_path = None
+         self._var_fields = var_index
+         self._verify_accessor_exists(self._var_fields.keys())
+         self._categoricals = categoricals
+         self._table_keys = set(self._var_fields.keys()) | set(
+             self._categoricals.keys() - {self._sample_metadata_key}
+         )
+         self._sample_df_curator = None
+         if self._sample_metadata_key is not None:
+             self._sample_metadata = self._sdata.get_attrs(
+                 key=self._sample_metadata_key, return_as="df", flatten=True
+             )
+         self._is_validated = False
+
+         # Check validity of keys in categoricals
+         nonval_keys = []
+         for accessor, accessor_categoricals in self._categoricals.items():
+             if (
+                 accessor == self._sample_metadata_key
+                 and self._sample_metadata is not None
+             ):
+                 for key in accessor_categoricals.keys():
+                     if key not in self._sample_metadata.columns:
+                         nonval_keys.append(key)
+             else:
+                 for key in accessor_categoricals.keys():
+                     if key not in self._sdata[accessor].obs.columns:
+                         nonval_keys.append(key)
+
+         _maybe_curation_keys_not_present(nonval_keys, "categoricals")
+
+         # check validity of keys in sources
+         nonval_keys = []
+         for accessor, accessor_sources in self._sources.items():
+             if (
+                 accessor == self._sample_metadata_key
+                 and self._sample_metadata is not None
+             ):
+                 columns = self._sample_metadata.columns
+             elif accessor != self._sample_metadata_key:
+                 columns = self._sdata[accessor].obs.columns
+             else:
+                 continue
+             for key in accessor_sources:
+                 if key not in columns:
+                     nonval_keys.append(key)
+         _maybe_curation_keys_not_present(nonval_keys, "sources")
+
+         # Set up sample level metadata and table Curator objects
+         if (
+             self._sample_metadata_key is not None
+             and self._sample_metadata_key in self._categoricals
+         ):
+             self._sample_df_curator = DataFrameCatManager(
+                 df=self._sample_metadata,
+                 columns_field=Feature.name,
+                 categoricals=self._categoricals.get(self._sample_metadata_key, {}),
+                 sources=self._sources.get(self._sample_metadata_key),
+             )
+         self._table_adata_curators = {
+             table: AnnDataCatManager(
+                 data=self._sdata[table],
+                 var_index=var_index.get(table),
+                 categoricals=self._categoricals.get(table),
+                 sources=self._sources.get(table),
+             )
+             for table in self._table_keys
+         }
+
+         self._non_validated = None
+
+     @property
+     def var_index(self) -> FieldAttr:
+         """Return the registry fields to validate variables indices against."""
+         return self._var_fields
+
+     @property
+     def categoricals(self) -> dict[str, dict[str, FieldAttr]]:
+         """Return the categorical keys and fields to validate against."""
+         return self._categoricals
+
+     @property
+     def non_validated(self) -> dict[str, dict[str, list[str]]]:  # type: ignore
+         """Return the non-validated features and labels."""
+         if self._non_validated is None:
+             raise ValidationError("Please run validate() first!")
+         non_curated = {}
+         if len(self._sample_df_curator.non_validated) > 0:
+             non_curated[self._sample_metadata_key] = (
+                 self._sample_df_curator.non_validated
+             )
+         for table, adata_curator in self._table_adata_curators.items():
+             if len(adata_curator.non_validated) > 0:
+                 non_curated[table] = adata_curator.non_validated
+         return non_curated
+
+     def _verify_accessor_exists(self, accessors: Iterable[str]) -> None:
+         """Verify that the accessors exist (either a valid table or in attrs)."""
+         for acc in accessors:
+             is_present = False
+             try:
+                 self._sdata.get_attrs(key=acc)
+                 is_present = True
+             except KeyError:
+                 if acc in self._sdata.tables.keys():
+                     is_present = True
+             if not is_present:
+                 raise ValidationError(f"Accessor '{acc}' does not exist!")
+
+     def lookup(self, public: bool = False) -> CatLookup:
+         """Look up categories.
+
+         Args:
+             public: Whether the lookup is performed on the public reference.
+         """
+         cat_values_dict = list(self.categoricals.values())[0]
+         return CatLookup(
+             categoricals=cat_values_dict,
+             slots={"accessors": cat_values_dict.keys()},
+             public=public,
+             sources=self._sources,
+         )
+
+     @deprecated(new_name="add_new_from('var_index')")
+     def add_new_from_var_index(self, table: str, **kwargs) -> None:
+         """Save new values from ``.var.index`` of table.
+
+         Args:
+             table: The table key.
+             **kwargs: Additional keyword arguments to pass to create new records.
+         """
+         if table in self.non_validated.keys():
+             self._table_adata_curators[table].add_new_from(key="var_index", **kwargs)
+
+     def add_new_from(
+         self,
+         key: str,
+         accessor: str | None = None,
+         **kwargs,
+     ) -> None:
+         """Save new values of categorical from sample level metadata or table.
+
+         Args:
+             key: The key referencing the slot in the DataFrame.
+             accessor: The accessor key such as 'sample' or 'table x'.
+             **kwargs: Additional keyword arguments to pass to create new records.
+         """
+         if accessor in self.non_validated.keys():
+             if accessor in self._table_adata_curators:
+                 adata_curator = self._table_adata_curators[accessor]
+                 adata_curator.add_new_from(key=key, **kwargs)
+             if accessor == self._sample_metadata_key:
+                 self._sample_df_curator.add_new_from(key=key, **kwargs)
+
+         if key == "var_index":
+             self._table_adata_curators[accessor].add_new_from(key=key, **kwargs)
+
+     def standardize(self, key: str, accessor: str | None = None) -> None:
+         """Replace synonyms with canonical values.
+
+         Modifies the dataset inplace.
+
+         Args:
+             key: The key referencing the slot in the table or sample metadata.
+             accessor: The accessor key such as 'sample_key' or 'table_key'.
+         """
+         if len(self.non_validated) == 0:
+             logger.warning("values are already standardized")
+             return
+         if self._artifact is not None:
+             raise RuntimeError("can't mutate the dataset when an artifact is passed!")
+
+         if accessor == self._sample_metadata_key:
+             if key not in self._sample_metadata.columns:
+                 raise ValueError(f"key '{key}' not present in '{accessor}'!")
+         else:
+             if (
+                 key == "var_index" and self._sdata.tables[accessor].var.index is None
+             ) or (
+                 key != "var_index"
+                 and key not in self._sdata.tables[accessor].obs.columns
+             ):
+                 raise ValueError(f"key '{key}' not present in '{accessor}'!")
+
+         if accessor in self._table_adata_curators.keys():
+             adata_curator = self._table_adata_curators[accessor]
+             adata_curator.standardize(key)
+         if accessor == self._sample_metadata_key:
+             self._sample_df_curator.standardize(key)
+
+     def validate(self) -> bool:
+         """Validate variables and categorical observations.
+
+         This method also registers the validated records in the current instance:
+         - from public sources
+
+         Returns:
+             Whether the SpatialData object is validated.
+         """
+         # add all validated records to the current instance
+         sample_validated = True
+         if self._sample_df_curator:
+             logger.info(f"validating categoricals of '{self._sample_metadata_key}' ...")
+             sample_validated &= self._sample_df_curator.validate()
+
+         mods_validated = True
+         for table, adata_curator in self._table_adata_curators.items():
+             logger.info(f"validating categoricals of table '{table}' ...")
+             mods_validated &= adata_curator.validate()
+
+         self._non_validated = {}  # so it's no longer None
+         self._is_validated = sample_validated & mods_validated
+         return self._is_validated
+
+     def save_artifact(
+         self,
+         *,
+         key: str | None = None,
+         description: str | None = None,
+         revises: Artifact | None = None,
+         run: Run | None = None,
+     ) -> Artifact:
+         """Save the validated SpatialData store and metadata.
+
+         Args:
+             description: A description of the dataset.
+             key: A path-like key to reference artifact in default storage,
+                 e.g., `"myartifact.zarr"`. Artifacts with the same key form a version family.
+             revises: Previous version of the artifact. Triggers a revision.
+             run: The run that creates the artifact.
+
+         Returns:
+             A saved artifact record.
+         """
+         if not self._is_validated:
+             self.validate()
+             if not self._is_validated:
+                 raise ValidationError("Dataset does not validate. Please curate.")
+
+         self._artifact = Artifact.from_spatialdata(
+             self._dataset, key=key, description=description, revises=revises, run=run
+         ).save()
+         return legacy_annotate_artifact(
+             self._artifact,
+             index_field=self.var_index,
+             sample_metadata_key=self._sample_metadata_key,
+         )
+
+
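
The nested argument shape the SpatialData manager expects, as a sketch (`sdata`, the `table` key, and the default `sample` attrs key are assumptions):

    import bionty as bt
    from lamindb.curators._legacy import SpatialDataCatManager

    curator = SpatialDataCatManager(
        sdata,
        var_index={"table": bt.Gene.ensembl_gene_id},
        categoricals={
            "sample": {"assay": bt.ExperimentalFactor.name},  # sample-level metadata
            "table": {"cell_type": bt.CellType.name},         # per-table obs columns
        },
    )
    curator.validate()
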
+ class TiledbsomaCatManager(CatManager):
+     """Categorical manager for `tiledbsoma.Experiment`."""
+
+     def __init__(
+         self,
+         experiment_uri: UPathStr | Artifact,
+         var_index: dict[str, tuple[str, FieldAttr]],
+         categoricals: dict[str, FieldAttr] | None = None,
+         obs_columns: FieldAttr = Feature.name,
+         sources: dict[str, Record] | None = None,
+     ):
+         self._obs_fields = categoricals or {}
+         self._var_fields = var_index
+         self._columns_field = obs_columns
+         if isinstance(experiment_uri, Artifact):
+             self._dataset = experiment_uri.path
+             self._artifact = experiment_uri
+         else:
+             self._dataset = UPath(experiment_uri)
+             self._artifact = None
+         self._sources = sources or {}
+
+         self._is_validated: bool | None = False
+         self._non_validated_values: dict[str, list] | None = None
+         self._validated_values: dict[str, list] = {}
+         # filled by _check_save_keys
+         self._n_obs: int | None = None
+         self._valid_obs_keys: list[str] | None = None
+         self._obs_pa_schema: pa.lib.Schema | None = (
+             None  # this is needed to create the obs feature set
+         )
+         self._valid_var_keys: list[str] | None = None
+         self._var_fields_flat: dict[str, FieldAttr] | None = None
+         self._check_save_keys()
+
+     # check that the provided keys in var_index and categoricals are available in the store
+     # and save features
+     def _check_save_keys(self):
+         from lamindb.core.storage._tiledbsoma import _open_tiledbsoma
+
+         with _open_tiledbsoma(self._dataset, mode="r") as experiment:
+             experiment_obs = experiment.obs
+             self._n_obs = len(experiment_obs)
+             self._obs_pa_schema = experiment_obs.schema
+             valid_obs_keys = [
+                 k for k in self._obs_pa_schema.names if k != "soma_joinid"
+             ]
+             self._valid_obs_keys = valid_obs_keys
+
+             valid_var_keys = []
+             ms_list = []
+             for ms in experiment.ms.keys():
+                 ms_list.append(ms)
+                 var_ms = experiment.ms[ms].var
+                 valid_var_keys += [
+                     f"{ms}__{k}" for k in var_ms.keys() if k != "soma_joinid"
+                 ]
+             self._valid_var_keys = valid_var_keys
+
+         # check validity of keys in categoricals
+         nonval_keys = []
+         for obs_key in self._obs_fields.keys():
+             if obs_key not in valid_obs_keys:
+                 nonval_keys.append(obs_key)
+         _maybe_curation_keys_not_present(nonval_keys, "categoricals")
+
+         # check validity of keys in var_index
+         self._var_fields_flat = {}
+         nonval_keys = []
+         for ms_key in self._var_fields.keys():
+             var_key, var_field = self._var_fields[ms_key]
+             var_key_flat = f"{ms_key}__{var_key}"
+             if var_key_flat not in valid_var_keys:
+                 nonval_keys.append(f"({ms_key}, {var_key})")
+             else:
+                 self._var_fields_flat[var_key_flat] = var_field
+         _maybe_curation_keys_not_present(nonval_keys, "var_index")
+
+         # check validity of keys in sources
+         valid_arg_keys = valid_obs_keys + valid_var_keys + ["columns"]
+         nonval_keys = []
+         for arg_key in self._sources.keys():
+             if arg_key not in valid_arg_keys:
+                 nonval_keys.append(arg_key)
+         _maybe_curation_keys_not_present(nonval_keys, "sources")
+
+         # register obs columns' names
+         register_columns = list(self._obs_fields.keys())
+         # register categorical keys as features
+         cat_vector = CatVector(
+             values_getter=register_columns,
+             field=self._columns_field,
+             key="columns",
+             source=self._sources.get("columns"),
+         )
+         cat_vector.add_new()
+
+     def validate(self):
+         """Validate categories."""
+         from lamindb.core.storage._tiledbsoma import _open_tiledbsoma
+
+         validated = True
+         self._non_validated_values = {}
+         with _open_tiledbsoma(self._dataset, mode="r") as experiment:
+             for ms, (key, field) in self._var_fields.items():
+                 var_ms = experiment.ms[ms].var
+                 var_ms_key = f"{ms}__{key}"
+                 # it was already validated and cached
+                 if var_ms_key in self._validated_values:
+                     continue
+                 var_ms_values = (
+                     var_ms.read(column_names=[key]).concat()[key].to_pylist()
+                 )
+                 cat_vector = CatVector(
+                     values_getter=var_ms_values,
+                     field=field,
+                     key=var_ms_key,
+                     source=self._sources.get(var_ms_key),
+                 )
+                 cat_vector.validate()
+                 non_val = cat_vector._non_validated
+                 if len(non_val) > 0:
+                     validated = False
+                     self._non_validated_values[var_ms_key] = non_val
+                 else:
+                     self._validated_values[var_ms_key] = var_ms_values
+
+             obs = experiment.obs
+             for key, field in self._obs_fields.items():
+                 # already validated and cached
+                 if key in self._validated_values:
+                     continue
+                 values = pa.compute.unique(
+                     obs.read(column_names=[key]).concat()[key]
+                 ).to_pylist()
+                 cat_vector = CatVector(
+                     values_getter=values,
+                     field=field,
+                     key=key,
+                     source=self._sources.get(key),
+                 )
+                 cat_vector.validate()
+                 non_val = cat_vector._non_validated
+                 if len(non_val) > 0:
+                     validated = False
+                     self._non_validated_values[key] = non_val
+                 else:
+                     self._validated_values[key] = values
+         self._is_validated = validated
+         return self._is_validated
+
+     def _non_validated_values_field(self, key: str) -> tuple[list, FieldAttr]:
+         assert self._non_validated_values is not None  # noqa: S101
+
+         if key in self._valid_obs_keys:
+             field = self._obs_fields[key]
+         elif key in self._valid_var_keys:
+             ms = key.partition("__")[0]
+             field = self._var_fields[ms][1]
+         else:
+             raise KeyError(f"key {key} is invalid!")
+         values = self._non_validated_values.get(key, [])
+         return values, field
+
+     def add_new_from(self, key: str, **kwargs) -> None:
+         """Add validated & new categories.
+
+         Args:
+             key: The key referencing the slot in the `tiledbsoma` store.
+                 It should be `'{measurement name}__{column name in .var}'` for columns in `.var`
+                 or a column name in `.obs`.
+         """
+         if self._non_validated_values is None:
+             raise ValidationError("Run .validate() first.")
+         if key == "all":
+             keys = list(self._non_validated_values.keys())
+         else:
+             avail_keys = list(
+                 chain(self._non_validated_values.keys(), self._validated_values.keys())
+             )
+             if key not in avail_keys:
+                 raise KeyError(
+                     f"{key!r} is not a valid key, available keys are: {_format_values(avail_keys + ['all'])}!"
+                 )
+             keys = [key]
+         for k in keys:
+             values, field = self._non_validated_values_field(k)
+             if len(values) == 0:
+                 continue
+             cat_vector = CatVector(
+                 values_getter=values,
+                 field=field,
+                 key=k,
+                 source=self._sources.get(k),
+             )
+             cat_vector.add_new()
+             # update non-validated values list but keep the key there
+             # it will be removed by .validate()
+             if k in self._non_validated_values:
+                 self._non_validated_values[k] = []
+
+     @property
+     def non_validated(self) -> dict[str, list]:
+         """Return the non-validated features and labels."""
+         non_val = {k: v for k, v in self._non_validated_values.items() if v != []}
+         return non_val
+
+     @property
+     def var_index(self) -> dict[str, FieldAttr]:
+         """Return the registry fields with flattened keys to validate variables indices against."""
+         return self._var_fields_flat
+
+     @property
+     def categoricals(self) -> dict[str, FieldAttr]:
+         """Return the obs fields to validate against."""
+         return self._obs_fields
+
+     def lookup(self, public: bool = False) -> CatLookup:
+         """Lookup categories.
+
+         Args:
+             public: If `True`, the lookup is performed on the public reference.
+         """
+         return CatLookup(
+             categoricals=self._obs_fields,
+             slots={"columns": self._columns_field, **self._var_fields_flat},
+             public=public,
+             sources=self._sources,
+         )
+
+     def standardize(self, key: str):
+         """Replace synonyms with standardized values.
+
+         Modifies the dataset inplace.
+
+         Args:
+             key: The key referencing the slot in the `tiledbsoma` store.
+                 It should be `'{measurement name}__{column name in .var}'` for columns in `.var`
+                 or a column name in `.obs`.
+         """
+         if len(self.non_validated) == 0:
+             logger.warning("values are already standardized")
+             return
+         avail_keys = list(self._non_validated_values.keys())
+         if key == "all":
+             keys = avail_keys
+         else:
+             if key not in avail_keys:
+                 raise KeyError(
+                     f"{key!r} is not a valid key, available keys are: {_format_values(avail_keys + ['all'])}!"
+                 )
+             keys = [key]
+
+         for k in keys:
+             values, field = self._non_validated_values_field(k)
+             if len(values) == 0:
+                 continue
+             if k in self._valid_var_keys:
+                 ms, _, slot_key = k.partition("__")
+                 slot = lambda experiment: experiment.ms[ms].var  # noqa: B023
+             else:
+                 slot = lambda experiment: experiment.obs
+                 slot_key = k
+             cat_vector = CatVector(
+                 values_getter=values,
+                 field=field,
+                 key=k,
+                 source=self._sources.get(k),
+             )
+             cat_vector.validate()
+             syn_mapper = cat_vector._synonyms
+             if (n_syn_mapper := len(syn_mapper)) == 0:
+                 continue
+
+             from lamindb.core.storage._tiledbsoma import _open_tiledbsoma
+
+             with _open_tiledbsoma(self._dataset, mode="r") as experiment:
+                 value_filter = f"{slot_key} in {list(syn_mapper.keys())}"
+                 table = slot(experiment).read(value_filter=value_filter).concat()
+
+             if len(table) == 0:
+                 continue
+
+             df = table.to_pandas()
+             # map values
+             df[slot_key] = df[slot_key].map(
+                 lambda val: syn_mapper.get(val, val)  # noqa
+             )
+             # write the mapped values
+             with _open_tiledbsoma(self._dataset, mode="w") as experiment:
+                 slot(experiment).write(pa.Table.from_pandas(df, schema=table.schema))
+             # update non_validated dict
+             non_val_k = [
+                 nv for nv in self._non_validated_values[k] if nv not in syn_mapper
+             ]
+             self._non_validated_values[k] = non_val_k
+
+             syn_mapper_print = _format_values(
+                 [f'"{m_k}" → "{m_v}"' for m_k, m_v in syn_mapper.items()], sep=""
+             )
+             s = "s" if n_syn_mapper > 1 else ""
+             logger.success(
+                 f'standardized {n_syn_mapper} synonym{s} in "{k}": {colors.green(syn_mapper_print)}'
+             )
+
+     def save_artifact(
+         self,
+         *,
+         key: str | None = None,
+         description: str | None = None,
+         revises: Artifact | None = None,
+         run: Run | None = None,
+     ) -> Artifact:
+         """Save the validated `tiledbsoma` store and metadata.
+
+         Args:
+             description: A description of the ``tiledbsoma`` store.
+             key: A path-like key to reference artifact in default storage,
+                 e.g., `"myfolder/mystore.tiledbsoma"`. Artifacts with the same key form a version family.
+             revises: Previous version of the artifact. Triggers a revision.
+             run: The run that creates the artifact.
+
+         Returns:
+             A saved artifact record.
+         """
+         if not self._is_validated:
+             self.validate()
+             if not self._is_validated:
+                 raise ValidationError("Dataset does not validate. Please curate.")
+
+         if self._artifact is None:
+             artifact = Artifact(
+                 self._dataset,
+                 description=description,
+                 key=key,
+                 revises=revises,
+                 run=run,
+             )
+             artifact.n_observations = self._n_obs
+             artifact.otype = "tiledbsoma"
+             artifact.save()
+         else:
+             artifact = self._artifact
+
+         feature_sets = {}
+         if len(self._obs_fields) > 0:
+             empty_dict = {field.name: [] for field in self._obs_pa_schema}  # type: ignore
+             mock_df = pa.Table.from_pydict(
+                 empty_dict, schema=self._obs_pa_schema
+             ).to_pandas()
+             # in parallel to https://github.com/laminlabs/lamindb/blob/2a1709990b5736b480c6de49c0ada47fafc8b18d/lamindb/core/_feature_manager.py#L549-L554
+             feature_sets["obs"] = Schema.from_df(
+                 df=mock_df,
+                 field=self._columns_field,
+                 mute=True,
+             )
+         for ms in self._var_fields:
+             var_key, var_field = self._var_fields[ms]
+             feature_sets[f"{ms}__var"] = Schema.from_values(
+                 values=self._validated_values[f"{ms}__{var_key}"],
+                 field=var_field,
+                 raise_validation_error=False,
+             )
+         artifact._staged_feature_sets = feature_sets
+
+         feature_ref_is_name = _ref_is_name(self._columns_field)
+         features = Feature.lookup().dict()
+         for key, field in self._obs_fields.items():
+             feature = features.get(key)
+             registry = field.field.model
+             labels = registry.from_values(
+                 values=self._validated_values[key],
+                 field=field,
+             )
+             if len(labels) == 0:
+                 continue
+             if hasattr(registry, "_name_field"):
+                 label_ref_is_name = field.field.name == registry._name_field
+             add_labels(
+                 artifact,
+                 records=labels,
+                 feature=feature,
+                 feature_ref_is_name=feature_ref_is_name,
+                 label_ref_is_name=label_ref_is_name,
+                 from_curator=True,
+             )
+
+         return artifact.save()
+
+
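
A hypothetical `TiledbsomaCatManager` session; the store URI and the `RNA` measurement with its `var_id` column are assumptions. Note the flattened `'{measurement}__{var column}'` key used for `.var` slots:

    import bionty as bt
    from lamindb.curators._legacy import TiledbsomaCatManager

    curator = TiledbsomaCatManager(
        "s3://mybucket/mystore.tiledbsoma",
        var_index={"RNA": ("var_id", bt.Gene.ensembl_gene_id)},
        categoricals={"cell_type": bt.CellType.name},
    )
    if not curator.validate():
        curator.standardize("RNA__var_id")  # rewrites synonyms in the store itself
        curator.add_new_from("cell_type")
    artifact = curator.save_artifact(key="myfolder/mystore.tiledbsoma")
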
+ class CellxGeneAnnDataCatManager(AnnDataCatManager):
+     """Categorical manager for `AnnData` respecting the CELLxGENE schema.
+
+     This will be superseded by a schema-based curation flow.
+     """
+
+     cxg_categoricals_defaults = {
+         "cell_type": "unknown",
+         "development_stage": "unknown",
+         "disease": "normal",
+         "donor_id": "unknown",
+         "self_reported_ethnicity": "unknown",
+         "sex": "unknown",
+         "suspension_type": "cell",
+         "tissue_type": "tissue",
+     }
+
+     def __init__(
+         self,
+         adata: ad.AnnData,
+         categoricals: dict[str, FieldAttr] | None = None,
+         *,
+         schema_version: Literal["4.0.0", "5.0.0", "5.1.0", "5.2.0"] = "5.2.0",
+         defaults: dict[str, str] | None = None,
+         extra_sources: dict[str, Record] | None = None,
+     ) -> None:
+         """CELLxGENE schema curator.
+
+         Args:
+             adata: Path to or AnnData object to curate against the CELLxGENE schema.
+             categoricals: A dictionary mapping ``.obs.columns`` to a registry field.
+                 The CELLxGENE Curator maps against the required CELLxGENE fields by default.
+             schema_version: The CELLxGENE schema version to curate against.
+             defaults: Default values that are set if columns or column values are missing.
+             extra_sources: A dictionary mapping ``.obs.columns`` to Source records.
+                 These extra sources are joined with the CELLxGENE fixed sources.
+                 Use this parameter when subclassing.
+         """
+         import bionty as bt
+
+         from ._cellxgene_schemas import (
+             _add_defaults_to_obs,
+             _create_sources,
+             _init_categoricals_additional_values,
+             _restrict_obs_fields,
+         )
+
+         # Add defaults first to ensure that we fetch valid sources
+         if defaults:
+             _add_defaults_to_obs(adata.obs, defaults)
+
+         # Filter categoricals based on what's present in adata
+         if categoricals is None:
+             categoricals = self._get_cxg_categoricals()
+         categoricals = _restrict_obs_fields(adata.obs, categoricals)
+
+         # Configure sources
+         organism: Literal["human", "mouse"] = "human"
+         sources = _create_sources(categoricals, schema_version, organism)
+         self.schema_version = schema_version
+         self.schema_reference = f"https://github.com/chanzuckerberg/single-cell-curation/blob/main/schema/{schema_version}/schema.md"
+         # These sources are not a part of the cellxgene schema but rather passed through.
+         # This is useful when other Curators extend the CELLxGENE curator
+         if extra_sources:
+             sources = sources | extra_sources
+
+         _init_categoricals_additional_values()
+
+         super().__init__(
+             data=adata,
+             var_index=bt.Gene.ensembl_gene_id,
+             categoricals=categoricals,
+             sources=sources,
+         )
+
+     @classmethod
+     def _get_cxg_categoricals(cls) -> dict[str, FieldAttr]:
+         """Returns the CELLxGENE schema mapped fields."""
+         from ._cellxgene_schemas import _get_cxg_categoricals
+
+         return _get_cxg_categoricals()
+
+     def validate(self) -> bool:
+         """Validates the AnnData object against most cellxgene requirements."""
+         from ._cellxgene_schemas import RESERVED_NAMES
+
+         # Verify that all required obs columns are present
+         required_columns = list(self.cxg_categoricals_defaults.keys()) + ["donor_id"]
+         missing_obs_fields = [
+             name
+             for name in required_columns
+             if name not in self._adata.obs.columns
+             and f"{name}_ontology_term_id" not in self._adata.obs.columns
+         ]
+         if len(missing_obs_fields) > 0:
+             logger.error(
+                 f"missing required obs columns {_format_values(missing_obs_fields)}\n"
+                 " → consider initializing a Curate object with `defaults=cxg.CellxGeneAnnDataCatManager.cxg_categoricals_defaults` to automatically add these columns with default values"
+             )
+             return False
+
+         # Verify that no cellxgene reserved names are present
+         matched_columns = [
+             column for column in self._adata.obs.columns if column in RESERVED_NAMES
+         ]
+         if len(matched_columns) > 0:
+             raise ValueError(
+                 f"AnnData object must not contain obs columns {matched_columns} which are"
+                 " reserved from previous schema versions."
+             )
+
+         return super().validate()
+
+     def to_cellxgene_anndata(
+         self, is_primary_data: bool, title: str | None = None
+     ) -> ad.AnnData:
+         """Converts the AnnData object to the cellxgene-schema input format.
+
+         cellxgene expects the obs fields to be {entity}_ontology_term_id fields and has many further requirements which are
+         documented here: https://github.com/chanzuckerberg/single-cell-curation/tree/main/schema.
+         This function checks for most but not all requirements of the CELLxGENE schema.
+         If you want to ensure that it fully adheres to the CELLxGENE schema, run `cellxgene-schema` on the AnnData object.
+
+         Args:
+             is_primary_data: Whether the measured data is primary data or not.
+             title: Title of the AnnData object. Commonly the name of the publication.
+
+         Returns:
+             An AnnData object which adheres to the cellxgene-schema.
+         """
+
+         def _convert_name_to_ontology_id(values: pd.Series, field: FieldAttr):
+             """Converts a column that stores a name into a column that stores the ontology id.
+
+             cellxgene expects the obs columns to be {entity}_ontology_term_id columns and disallows {entity} columns.
+             """
+             field_name = field.field.name
+             assert field_name == "name"  # noqa: S101
+             cols = ["name", "ontology_id"]
+             registry = field.field.model
+
+             if hasattr(registry, "ontology_id"):
+                 validated_records = registry.filter(**{f"{field_name}__in": values})
+                 mapper = (
+                     pd.DataFrame(validated_records.values_list(*cols))
+                     .set_index(0)
+                     .to_dict()[1]
+                 )
+                 return values.map(mapper)
+
+         # Create a copy since we modify the AnnData object extensively
+         adata_cxg = self._adata.copy()
+
+         # cellxgene requires an embedding
+         embedding_pattern = r"^[a-zA-Z][a-zA-Z0-9_.-]*$"
+         exclude_key = "spatial"
+         matching_keys = [
+             key
+             for key in adata_cxg.obsm.keys()
+             if re.match(embedding_pattern, key) and key != exclude_key
+         ]
+         if len(matching_keys) == 0:
+             raise ValueError(
+                 "Unable to find an embedding key. Please calculate an embedding."
+             )
+
+         # convert name column to ontology_term_id column
+         for column in adata_cxg.obs.columns:
+             if column in self.categoricals and not column.endswith("_ontology_term_id"):
+                 mapped_column = _convert_name_to_ontology_id(
+                     adata_cxg.obs[column], field=self.categoricals.get(column)
+                 )
+                 if mapped_column is not None:
+                     adata_cxg.obs[f"{column}_ontology_term_id"] = mapped_column
+
+         # drop the name columns for ontologies. cellxgene does not allow them.
+         drop_columns = [
+             i
+             for i in adata_cxg.obs.columns
+             if f"{i}_ontology_term_id" in adata_cxg.obs.columns
+         ]
+         adata_cxg.obs.drop(columns=drop_columns, inplace=True)
+
+         # Add cellxgene metadata to AnnData object
+         if "is_primary_data" not in adata_cxg.obs.columns:
+             adata_cxg.obs["is_primary_data"] = is_primary_data
+         if "feature_is_filtered" not in adata_cxg.var.columns:
+             logger.warn(
+                 "column 'feature_is_filtered' not present in var. Setting to default"
+                 " value of False."
+             )
+             adata_cxg.var["feature_is_filtered"] = False
+         if title is None:
+             raise ValueError("please pass a title!")
+         else:
+             adata_cxg.uns["title"] = title
+         adata_cxg.uns["cxg_lamin_schema_reference"] = self.schema_reference
+         adata_cxg.uns["cxg_lamin_schema_version"] = self.schema_version
+
+         return adata_cxg
+
+
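
A hedged end-to-end sketch for the CELLxGENE manager, assuming `adata` is an AnnData indexed by Ensembl gene ids:

    from lamindb.curators._legacy import CellxGeneAnnDataCatManager

    curator = CellxGeneAnnDataCatManager(
        adata,
        defaults=CellxGeneAnnDataCatManager.cxg_categoricals_defaults,
        schema_version="5.2.0",
    )
    if curator.validate():
        adata_cxg = curator.to_cellxgene_anndata(
            is_primary_data=True, title="My publication"
        )
        adata_cxg.write_h5ad("for_cellxgene.h5ad")  # then run `cellxgene-schema`
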
+ class ValueUnit:
+     """Base class for handling value-unit combinations."""
+
+     @staticmethod
+     def parse_value_unit(value: str, is_dose: bool = True) -> tuple[str, str] | None:
+         """Parse a string containing a value and unit into a tuple."""
+         if not isinstance(value, str) or not value.strip():
+             return None
+
+         value = str(value).strip()
+         match = re.match(r"^(\d*\.?\d{0,1})\s*([a-zA-ZμµΜ]+)$", value)
+
+         if not match:
+             raise ValueError(
+                 f"Invalid format: {value}. Expected format: number with max 1 decimal place + unit"
+             )
+
+         number, unit = match.groups()
+         formatted_number = f"{float(number):.1f}"
+
+         if is_dose:
+             standardized_unit = DoseHandler.standardize_unit(unit)
+             if not DoseHandler.validate_unit(standardized_unit):
+                 raise ValueError(
+                     f"Invalid dose unit: {unit}. Must be convertible to one of: nM, μM, mM, M"
+                 )
+         else:
+             standardized_unit = TimeHandler.standardize_unit(unit)
+             if not TimeHandler.validate_unit(standardized_unit):
+                 raise ValueError(
+                     f"Invalid time unit: {unit}. Must be convertible to one of: h, m, s, d, y"
+                 )
+
+         return formatted_number, standardized_unit
+
+
1559
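A few illustrative calls, following directly from the regex above and the unit handlers defined below (a sketch, not part of the module):

    ValueUnit.parse_value_unit("10um", is_dose=True)     # -> ("10.0", "μM")
    ValueUnit.parse_value_unit("2.5 hr", is_dose=False)  # -> ("2.5", "h")
    ValueUnit.parse_value_unit("", is_dose=True)         # -> None (empty input)
    ValueUnit.parse_value_unit("10.25mM", is_dose=True)  # raises ValueError: max 1 decimal place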
+ class DoseHandler:
+     """Handler for dose-related operations."""
+
+     VALID_UNITS = {"nM", "μM", "µM", "mM", "M"}
+     UNIT_MAP = {
+         "nm": "nM",
+         "NM": "nM",
+         "um": "μM",
+         "uM": "μM",
+         "UM": "μM",
+         "μm": "μM",
+         "μM": "μM",
+         "µm": "μM",
+         "µM": "μM",
+         "mm": "mM",
+         "MM": "mM",
+         "m": "M",
+         "M": "M",
+     }
+
+     @classmethod
+     def validate_unit(cls, unit: str) -> bool:
+         """Validate if the dose unit is acceptable."""
+         return unit in cls.VALID_UNITS
+
+     @classmethod
+     def standardize_unit(cls, unit: str) -> str:
+         """Standardize dose unit to standard formats."""
+         return cls.UNIT_MAP.get(unit, unit)
+
+     @classmethod
+     def validate_values(cls, values: pd.Series) -> list[str]:
+         """Validate pert_dose values with strict case checking."""
+         errors = []
+
+         for idx, value in values.items():
+             if pd.isna(value):
+                 continue
+
+             if isinstance(value, (int, float)):
+                 errors.append(
+                     f"Row {idx} - Missing unit for dose: {value}. Must include a unit (nM, μM, mM, M)"
+                 )
+                 continue
+
+             try:
+                 ValueUnit.parse_value_unit(value, is_dose=True)
+             except ValueError as e:
+                 errors.append(f"Row {idx} - {str(e)}")
+
+         return errors
+
+
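How `DoseHandler.validate_values` reports problems, sketched on illustrative values (the series index supplies the row labels in the messages):

    import pandas as pd

    doses = pd.Series(["10um", 5.0, "1.5nm", "2xyz"])
    errors = DoseHandler.validate_values(doses)
    # errors[0]: "Row 1 - Missing unit for dose: 5.0. Must include a unit (nM, μM, mM, M)"
    # errors[1]: "Row 3 - Invalid dose unit: xyz. Must be convertible to one of: nM, μM, mM, M"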
+ class TimeHandler:
+     """Handler for time-related operations."""
+
+     VALID_UNITS = {"h", "m", "s", "d", "y"}
+
+     @classmethod
+     def validate_unit(cls, unit: str) -> bool:
+         """Validate if the time unit is acceptable."""
+         return unit == unit.lower() and unit in cls.VALID_UNITS
+
+     @classmethod
+     def standardize_unit(cls, unit: str) -> str:
+         """Standardize time unit to standard formats."""
+         if unit.startswith("hr"):
+             return "h"
+         elif unit.startswith("min"):
+             return "m"
+         elif unit.startswith("sec"):
+             return "s"
+         return unit[0].lower()
+
+     @classmethod
+     def validate_values(cls, values: pd.Series) -> list[str]:
+         """Validate pert_time values."""
+         errors = []
+
+         for idx, value in values.items():
+             if pd.isna(value):
+                 continue
+
+             if isinstance(value, (int, float)):
+                 errors.append(
+                     f"Row {idx} - Missing unit for time: {value}. Must include a unit (h, m, s, d, y)"
+                 )
+                 continue
+
+             try:
+                 ValueUnit.parse_value_unit(value, is_dose=False)
+             except ValueError as e:
+                 errors.append(f"Row {idx} - {str(e)}")
+
+         return errors
+
+
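The time handler folds common spellings onto single-letter units and rejects upper-case variants; anything unrecognized falls back to its first letter, lower-cased. Illustrative calls (a sketch):

    TimeHandler.standardize_unit("hr")       # -> "h"
    TimeHandler.standardize_unit("minutes")  # -> "m"
    TimeHandler.standardize_unit("Days")     # -> "d" (first letter, lower-cased)
    TimeHandler.validate_unit("H")           # -> False (upper-case units are rejected)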
+ class PertAnnDataCatManager(CellxGeneAnnDataCatManager):
+     """Categorical manager for `AnnData` to manage perturbations."""
+
+     PERT_COLUMNS = {"compound", "genetic", "biologic", "physical"}
+
+     def __init__(
+         self,
+         adata: ad.AnnData,
+         organism: Literal["human", "mouse"] = "human",
+         pert_dose: bool = True,
+         pert_time: bool = True,
+         *,
+         cxg_schema_version: Literal["5.0.0", "5.1.0", "5.2.0"] = "5.2.0",
+     ):
+         """Initialize the curator with configuration and validation settings."""
+         self._pert_time = pert_time
+         self._pert_dose = pert_dose
+
+         self._validate_initial_data(adata)
+         categoricals, categoricals_defaults = self._configure_categoricals(adata)
+
+         super().__init__(
+             adata=adata,
+             categoricals=categoricals,
+             defaults=categoricals_defaults,
+             extra_sources=self._configure_sources(adata),
+             schema_version=cxg_schema_version,
+         )
+
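A minimal construction sketch, assuming a loaded LaminDB instance with `bionty` and `wetlab` installed; the toy `AnnData` below is hypothetical:

    import anndata as ad
    import pandas as pd

    adata = ad.AnnData(
        obs=pd.DataFrame(
            {
                "pert_name": ["aspirin", "sgRNA-TP53"],
                "pert_type": ["compound", "genetic"],
                "pert_dose": ["10um", None],
                "pert_time": ["24h", "24h"],
            }
        )
    )
    curator = PertAnnDataCatManager(adata, pert_dose=True, pert_time=True)
    curator.validate()  # also checks pert_dose/pert_time formats, see below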
+
1684
+ def _configure_categoricals(self, adata: ad.AnnData):
1685
+ """Set up default configuration values."""
1686
+ import bionty as bt
1687
+ import wetlab as wl
1688
+
1689
+ categoricals = CellxGeneAnnDataCatManager._get_cxg_categoricals() | {
1690
+ k: v
1691
+ for k, v in {
1692
+ "cell_line": bt.CellLine.name,
1693
+ "pert_target": wl.PerturbationTarget.name,
1694
+ "pert_genetic": wl.GeneticPerturbation.name,
1695
+ "pert_compound": wl.Compound.name,
1696
+ "pert_biologic": wl.Biologic.name,
1697
+ "pert_physical": wl.EnvironmentalPerturbation.name,
1698
+ }.items()
1699
+ if k in adata.obs.columns
1700
+ }
1701
+ # if "donor_id" in categoricals:
1702
+ # categoricals["donor_id"] = Donor.name
1703
+
1704
+ categoricals_defaults = CellxGeneAnnDataCatManager.cxg_categoricals_defaults | {
1705
+ "cell_line": "unknown",
1706
+ "pert_target": "unknown",
1707
+ }
1708
+
1709
+ return categoricals, categoricals_defaults
1710
+
1711
+ def _configure_sources(self, adata: ad.AnnData):
1712
+ """Set up data sources."""
1713
+ import bionty as bt
1714
+ import wetlab as wl
1715
+
1716
+ sources = {}
1717
+ # # do not yet specify cell_line source
1718
+ # if "cell_line" in adata.obs.columns:
1719
+ # sources["cell_line"] = bt.Source.filter(
1720
+ # entity="bionty.CellLine", name="depmap"
1721
+ # ).first()
1722
+ if "pert_compound" in adata.obs.columns:
1723
+ with logger.mute():
1724
+ chebi_source = bt.Source.filter(
1725
+ entity="wetlab.Compound", name="chebi"
1726
+ ).first()
1727
+ if not chebi_source:
1728
+ wl.Compound.add_source(
1729
+ bt.Source.filter(entity="Drug", name="chebi").first()
1730
+ )
1731
+
1732
+ sources["pert_compound"] = bt.Source.filter(
1733
+ entity="wetlab.Compound", name="chebi"
1734
+ ).first()
1735
+ return sources
1736
+
1737
+ def _validate_initial_data(self, adata: ad.AnnData):
1738
+ """Validate the initial data structure."""
1739
+ self._validate_required_columns(adata)
1740
+ self._validate_perturbation_types(adata)
1741
+
1742
+ def _validate_required_columns(self, adata: ad.AnnData):
1743
+ """Validate required columns are present."""
1744
+ if "pert_target" not in adata.obs.columns:
1745
+ if (
1746
+ "pert_name" not in adata.obs.columns
1747
+ or "pert_type" not in adata.obs.columns
1748
+ ):
1749
+ raise ValidationError(
1750
+ "either 'pert_target' or both 'pert_name' and 'pert_type' must be present"
1751
+ )
1752
+ else:
1753
+ if "pert_name" not in adata.obs.columns:
1754
+ logger.warning(
1755
+ "no 'pert' column found in adata.obs, will only curate 'pert_target'"
1756
+ )
1757
+ elif "pert_type" not in adata.obs.columns:
1758
+ raise ValidationError("both 'pert' and 'pert_type' must be present")
1759
+
1760
+ def _validate_perturbation_types(self, adata: ad.AnnData):
1761
+ """Validate perturbation types."""
1762
+ if "pert_type" in adata.obs.columns:
1763
+ data_pert_types = set(adata.obs["pert_type"].unique())
1764
+ invalid_pert_types = data_pert_types - self.PERT_COLUMNS
1765
+ if invalid_pert_types:
1766
+ raise ValidationError(
1767
+ f"invalid pert_type found: {invalid_pert_types}!\n"
1768
+ f" → allowed values: {self.PERT_COLUMNS}"
1769
+ )
1770
+ self._process_perturbation_types(adata, data_pert_types)
1771
+
1772
+ def _process_perturbation_types(self, adata: ad.AnnData, pert_types: set):
1773
+ """Process and map perturbation types."""
1774
+ for pert_type in pert_types:
1775
+ col_name = "pert_" + pert_type
1776
+ adata.obs[col_name] = adata.obs["pert_name"].where(
1777
+ adata.obs["pert_type"] == pert_type, None
1778
+ )
1779
+ if adata.obs[col_name].dtype.name == "category":
1780
+ adata.obs[col_name].cat.remove_unused_categories()
1781
+ logger.important(f"mapped 'pert_name' to '{col_name}'")
1782
+
1783
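The mapping step above splits the single `pert_name` column into one column per perturbation type, keyed on `pert_type`. The same `where` trick in standalone pandas (illustrative data):

    import pandas as pd

    obs = pd.DataFrame(
        {
            "pert_name": ["aspirin", "sgRNA-TP53", "heat"],
            "pert_type": ["compound", "genetic", "physical"],
        }
    )
    for pert_type in set(obs["pert_type"]):
        # keep pert_name only where pert_type matches; other rows become None
        obs["pert_" + pert_type] = obs["pert_name"].where(
            obs["pert_type"] == pert_type, None
        )
    # obs now has pert_compound, pert_genetic, and pert_physical columns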
+     def validate(self) -> bool:  # type: ignore
+         """Validate the AnnData object."""
+         validated = super().validate()
+
+         if self._pert_dose:
+             validated &= self._validate_dose_column()
+         if self._pert_time:
+             validated &= self._validate_time_column()
+
+         self._is_validated = validated
+
+         # sort columns
+         first_columns = [
+             "pert_target",
+             "pert_genetic",
+             "pert_compound",
+             "pert_biologic",
+             "pert_physical",
+             "pert_dose",
+             "pert_time",
+             "organism",
+             "cell_line",
+             "cell_type",
+             "disease",
+             "tissue_type",
+             "tissue",
+             "assay",
+             "suspension_type",
+             "donor_id",
+             "sex",
+             "self_reported_ethnicity",
+             "development_stage",
+             "pert_name",
+             "pert_type",
+         ]
+         sorted_columns = [
+             col for col in first_columns if col in self._adata.obs.columns
+         ] + [col for col in self._adata.obs.columns if col not in first_columns]
+         # must assign to self._obs_df to ensure .standardize works correctly
+         self._obs_df = self._adata.obs[sorted_columns]
+         self._adata.obs = self._obs_df
+         return validated
+
+     def standardize(self, key: str) -> None:
+         """Standardize the AnnData object."""
+         super().standardize(key)
+         self._adata.obs = self._obs_df
+
+     def _validate_dose_column(self) -> bool:
+         """Validate the dose column."""
+         if not Feature.filter(name="pert_dose").exists():
+             Feature(name="pert_dose", dtype="str").save()  # type: ignore
+
+         dose_errors = DoseHandler.validate_values(self._adata.obs["pert_dose"])
+         if dose_errors:
+             self._log_validation_errors("pert_dose", dose_errors)
+             return False
+         return True
+
+     def _validate_time_column(self) -> bool:
+         """Validate the time column."""
+         if not Feature.filter(name="pert_time").exists():
+             Feature(name="pert_time", dtype="str").save()  # type: ignore
+
+         time_errors = TimeHandler.validate_values(self._adata.obs["pert_time"])
+         if time_errors:
+             self._log_validation_errors("pert_time", time_errors)
+             return False
+         return True
+
+     def _log_validation_errors(self, column: str, errors: list):
+         """Log validation errors with formatting."""
+         errors_print = "\n ".join(errors)
+         logger.warning(
+             f"invalid {column} values found!\n {errors_print}\n"
+             f" → run {colors.cyan('standardize_dose_time()')}"
+         )
+
+     def standardize_dose_time(self) -> pd.DataFrame:
+         """Standardize dose and time values."""
+         standardized_df = self._adata.obs.copy()
+
+         if "pert_dose" in self._adata.obs.columns:
+             standardized_df = self._standardize_column(
+                 standardized_df, "pert_dose", is_dose=True
+             )
+
+         if "pert_time" in self._adata.obs.columns:
+             standardized_df = self._standardize_column(
+                 standardized_df, "pert_time", is_dose=False
+             )
+
+         self._adata.obs = standardized_df
+         return standardized_df
+
+     def _standardize_column(
+         self, df: pd.DataFrame, column: str, is_dose: bool
+     ) -> pd.DataFrame:
+         """Standardize values in a specific column."""
+         for idx, value in self._adata.obs[column].items():
+             if pd.isna(value) or (
+                 isinstance(value, str) and (not value.strip() or value.lower() == "nan")
+             ):
+                 df.at[idx, column] = None
+                 continue
+
+             try:
+                 num, unit = ValueUnit.parse_value_unit(value, is_dose=is_dose)
+                 df.at[idx, column] = f"{num}{unit}"
+             except ValueError:
+                 continue
+
+         return df
+
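A sketch of the standardization pass on dose and time strings (hypothetical `curator` with three observations; unparseable values are deliberately left untouched so that a re-run of `validate()` can flag them):

    curator._adata.obs["pert_dose"] = ["10um", "2.5 mm", "bad-value"]
    df = curator.standardize_dose_time()
    # df["pert_dose"].tolist() -> ["10.0μM", "2.5mM", "bad-value"]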
+
+ def legacy_annotate_artifact(
+     artifact: Artifact,
+     *,
+     cat_vectors: dict[str, CatVector] | None = None,
+     index_field: FieldAttr | dict[str, FieldAttr] | None = None,
+     **kwargs,
+ ) -> Artifact:
+     from ..models.artifact import add_labels
+
+     if cat_vectors is None:
+         cat_vectors = {}
+
+     # annotate with labels
+     for key, cat_vector in cat_vectors.items():
+         if (
+             cat_vector._field.field.model == Feature
+             or key == "columns"
+             or key == "var_index"
+         ):
+             continue
+         add_labels(
+             artifact,
+             records=cat_vector.records,
+             feature=cat_vector.feature,
+             feature_ref_is_name=None,  # no longer needed
+             label_ref_is_name=cat_vector.label_ref_is_name,
+             from_curator=True,
+         )
+
+     match artifact.otype:
+         case "DataFrame":
+             artifact.features._add_set_from_df(field=index_field)  # type: ignore
+         case "AnnData":
+             artifact.features._add_set_from_anndata(  # type: ignore
+                 var_field=index_field,
+             )
+         case "MuData":
+             artifact.features._add_set_from_mudata(var_fields=index_field)  # type: ignore
+         case "SpatialData":
+             artifact.features._add_set_from_spatialdata(  # type: ignore
+                 sample_metadata_key=kwargs.get("sample_metadata_key", "sample"),
+                 var_fields=index_field,
+             )
+         case _:
+             raise NotImplementedError  # pragma: no cover
+
+     return artifact
+
+
+ # backward compat constructors ------------------
+
+
+ @classmethod  # type: ignore
+ def from_df(
+     cls,
+     df: pd.DataFrame,
+     categoricals: dict[str, FieldAttr] | None = None,
+     columns: FieldAttr = Feature.name,
+     organism: str | None = None,
+ ) -> DataFrameCatManager:
+     if organism is not None:
+         logger.warning("organism is ignored, define it on the dtype level")
+     return DataFrameCatManager(
+         df=df,
+         categoricals=categoricals,
+         columns_field=columns,
+     )
+
+
+ @classmethod  # type: ignore
+ def from_anndata(
+     cls,
+     data: ad.AnnData | UPathStr,
+     var_index: FieldAttr,
+     categoricals: dict[str, FieldAttr] | None = None,
+     obs_columns: FieldAttr = Feature.name,
+     organism: str | None = None,
+     sources: dict[str, Record] | None = None,
+ ) -> AnnDataCatManager:
+     if organism is not None:
+         logger.warning("organism is ignored, define it on the dtype level")
+     return AnnDataCatManager(
+         data=data,
+         var_index=var_index,
+         categoricals=categoricals,
+         obs_columns=obs_columns,
+         sources=sources,
+     )
+
+
+ @classmethod  # type: ignore
+ def from_mudata(
+     cls,
+     mdata: MuData | UPathStr,
+     var_index: dict[str, dict[str, FieldAttr]],
+     categoricals: dict[str, FieldAttr] | None = None,
+     organism: str | None = None,
+ ) -> MuDataCatManager:
+     if not is_package_installed("mudata"):
+         raise ImportError("Please install mudata: pip install mudata")
+     if organism is not None:
+         logger.warning("organism is ignored, define it on the dtype level")
+     return MuDataCatManager(
+         mdata=mdata,
+         var_index=var_index,
+         categoricals=categoricals,
+     )
+
+
+ @classmethod  # type: ignore
+ def from_tiledbsoma(
+     cls,
+     experiment_uri: UPathStr,
+     var_index: dict[str, tuple[str, FieldAttr]],
+     categoricals: dict[str, FieldAttr] | None = None,
+     obs_columns: FieldAttr = Feature.name,
+     organism: str | None = None,
+     sources: dict[str, Record] | None = None,
+ ) -> TiledbsomaCatManager:
+     if organism is not None:
+         logger.warning("organism is ignored, define it on the dtype level")
+     return TiledbsomaCatManager(
+         experiment_uri=experiment_uri,
+         var_index=var_index,
+         categoricals=categoricals,
+         obs_columns=obs_columns,
+         sources=sources,
+     )
+
+
+ @classmethod  # type: ignore
+ def from_spatialdata(
+     cls,
+     sdata: SpatialData | UPathStr,
+     var_index: dict[str, FieldAttr],
+     categoricals: dict[str, dict[str, FieldAttr]] | None = None,
+     organism: str | None = None,
+     sources: dict[str, dict[str, Record]] | None = None,
+     *,
+     sample_metadata_key: str = "sample",
+ ):
+     if not is_package_installed("spatialdata"):
+         raise ImportError("Please install spatialdata: pip install spatialdata")
+     if organism is not None:
+         logger.warning("organism is ignored, define it on the dtype level")
+     return SpatialDataCatManager(
+         sdata=sdata,
+         var_index=var_index,
+         categoricals=categoricals,
+         sources=sources,
+         sample_metadata_key=sample_metadata_key,
+     )
+
+
+ CatManager.from_df = from_df  # type: ignore
+ CatManager.from_anndata = from_anndata  # type: ignore
+ CatManager.from_mudata = from_mudata  # type: ignore
+ CatManager.from_spatialdata = from_spatialdata  # type: ignore
+ CatManager.from_tiledbsoma = from_tiledbsoma  # type: ignore
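These shims keep the pre-1.5 entry points working by attaching the module-level functions as classmethods on `CatManager`. A usage sketch (the DataFrame and the `bionty` field choice are illustrative):

    import bionty as bt
    import pandas as pd

    df = pd.DataFrame({"cell_type": ["T cell", "B cell"]})
    manager = CatManager.from_df(df, categoricals={"cell_type": bt.CellType.name})
    manager.validate()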