lamindb 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. lamindb/__init__.py +52 -36
  2. lamindb/_finish.py +17 -10
  3. lamindb/_tracked.py +1 -1
  4. lamindb/base/__init__.py +3 -1
  5. lamindb/base/fields.py +40 -22
  6. lamindb/base/ids.py +1 -94
  7. lamindb/base/types.py +2 -0
  8. lamindb/base/uids.py +117 -0
  9. lamindb/core/_context.py +203 -102
  10. lamindb/core/_settings.py +38 -25
  11. lamindb/core/datasets/__init__.py +11 -4
  12. lamindb/core/datasets/_core.py +5 -5
  13. lamindb/core/datasets/_small.py +0 -93
  14. lamindb/core/datasets/mini_immuno.py +172 -0
  15. lamindb/core/loaders.py +1 -1
  16. lamindb/core/storage/_backed_access.py +100 -6
  17. lamindb/core/storage/_polars_lazy_df.py +51 -0
  18. lamindb/core/storage/_pyarrow_dataset.py +15 -30
  19. lamindb/core/storage/_tiledbsoma.py +29 -13
  20. lamindb/core/storage/objects.py +6 -0
  21. lamindb/core/subsettings/__init__.py +2 -0
  22. lamindb/core/subsettings/_annotation_settings.py +11 -0
  23. lamindb/curators/__init__.py +7 -3349
  24. lamindb/curators/_legacy.py +2056 -0
  25. lamindb/curators/core.py +1534 -0
  26. lamindb/errors.py +11 -0
  27. lamindb/examples/__init__.py +27 -0
  28. lamindb/examples/schemas/__init__.py +12 -0
  29. lamindb/examples/schemas/_anndata.py +25 -0
  30. lamindb/examples/schemas/_simple.py +19 -0
  31. lamindb/integrations/_vitessce.py +8 -5
  32. lamindb/migrations/0091_alter_featurevalue_options_alter_space_options_and_more.py +24 -0
  33. lamindb/migrations/0092_alter_artifactfeaturevalue_artifact_and_more.py +75 -0
  34. lamindb/migrations/0093_alter_schemacomponent_unique_together.py +16 -0
  35. lamindb/models/__init__.py +4 -1
  36. lamindb/models/_describe.py +21 -4
  37. lamindb/models/_feature_manager.py +382 -287
  38. lamindb/models/_label_manager.py +8 -2
  39. lamindb/models/artifact.py +177 -106
  40. lamindb/models/artifact_set.py +122 -0
  41. lamindb/models/collection.py +73 -52
  42. lamindb/models/core.py +1 -1
  43. lamindb/models/feature.py +51 -17
  44. lamindb/models/has_parents.py +69 -14
  45. lamindb/models/project.py +1 -1
  46. lamindb/models/query_manager.py +221 -22
  47. lamindb/models/query_set.py +247 -172
  48. lamindb/models/record.py +65 -247
  49. lamindb/models/run.py +4 -4
  50. lamindb/models/save.py +8 -2
  51. lamindb/models/schema.py +456 -184
  52. lamindb/models/transform.py +2 -2
  53. lamindb/models/ulabel.py +8 -5
  54. {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/METADATA +6 -6
  55. {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/RECORD +57 -43
  56. {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/LICENSE +0 -0
  57. {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1534 @@
1
+ """Curator utilities.
2
+
3
+ .. autosummary::
4
+ :toctree: .
5
+
6
+ Curator
7
+ SlotsCurator
8
+ CatVector
9
+ CatLookup
10
+ DataFrameCatManager
11
+
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import copy
17
+ import re
18
+ from typing import TYPE_CHECKING, Any, Callable
19
+
20
+ import lamindb_setup as ln_setup
21
+ import numpy as np
22
+ import pandas as pd
23
+ import pandera
24
+ from lamin_utils import colors, logger
25
+ from lamindb_setup.core._docs import doc_args
26
+
27
+ from lamindb.base.types import FieldAttr # noqa
28
+ from lamindb.models import (
29
+ Artifact,
30
+ Feature,
31
+ Record,
32
+ Run,
33
+ Schema,
34
+ )
35
+ from lamindb.models._from_values import _format_values
36
+ from lamindb.models.artifact import (
37
+ data_is_anndata,
38
+ data_is_mudata,
39
+ data_is_spatialdata,
40
+ )
41
+ from lamindb.models.feature import parse_cat_dtype, parse_dtype
42
+
43
+ from ..errors import InvalidArgument, ValidationError
44
+
45
+ if TYPE_CHECKING:
46
+ from collections.abc import Iterable
47
+ from typing import Any
48
+
49
+ from anndata import AnnData
50
+ from mudata import MuData
51
+ from spatialdata import SpatialData
52
+
53
+ from lamindb.models.query_set import RecordList
54
+
55
+
56
def strip_ansi_codes(text):
    """Return `text` with all ANSI color/style escape sequences removed."""
    # matches SGR sequences such as "\x1b[31m" or "\x1b[0m"
    return re.sub(r"\x1b\[[0-9;]*m", "", text)
60
+
61
+
62
class CatLookup:
    """Lookup categories from the reference instance.

    Args:
        categoricals: A dictionary of categorical fields to lookup.
        slots: A dictionary of slot fields to lookup.
        public: Whether to lookup from the public instance. Defaults to False.
        sources: Optional mapping of field keys to ontology sources for public lookups.

    Example::

        curator = ln.curators.DataFrameCurator(...)
        curator.cat.lookup()["cell_type"].alveolar_type_1_fibroblast_cell

    """

    def __init__(
        self,
        categoricals: list[Feature] | dict[str, FieldAttr],
        slots: dict[str, FieldAttr] = None,
        public: bool = False,
        sources: dict[str, Record] | None = None,
    ) -> None:
        slots = slots or {}
        if isinstance(categoricals, list):
            # resolve each feature's categorical dtype to its target field
            categoricals = {
                feature.name: parse_dtype(feature.dtype)[0]["field"]
                for feature in categoricals
            }
        self._categoricals = {**categoricals, **slots}
        self._public = public
        # fix: default to an empty dict so that `self._sources.get(name)` in
        # `_lookup` does not raise AttributeError when no sources were passed
        self._sources = sources or {}

    def _lookup(self, name):
        """Return the lookup object for `name` or raise AttributeError."""
        if name in self._categoricals:
            registry = self._categoricals[name].field.model
            if self._public and hasattr(registry, "public"):
                return registry.public(source=self._sources.get(name)).lookup()
            else:
                return registry.lookup()
        raise AttributeError(
            f'"{self.__class__.__name__}" object has no attribute "{name}"'
        )

    def __getattr__(self, name):
        # attribute access and item access share the same lookup logic
        return self._lookup(name)

    def __getitem__(self, name):
        return self._lookup(name)

    def __repr__(self) -> str:
        if len(self._categoricals) > 0:
            # identifier-like keys are reachable via attribute access,
            # all others via item access
            getattr_keys = "\n ".join(
                [f".{key}" for key in self._categoricals if key.isidentifier()]
            )
            getitem_keys = "\n ".join(
                [str([key]) for key in self._categoricals if not key.isidentifier()]
            )
            ref = "public" if self._public else "registries"
            return (
                f"Lookup objects from the {colors.italic(ref)}:\n "
                f"{colors.green(getattr_keys)}\n "
                f"{colors.green(getitem_keys)}\n"
                'Example:\n → categories = curator.lookup()["cell_type"]\n'
                " → categories.alveolar_type_1_fibroblast_cell\n\n"
                "To look up public ontologies, use .lookup(public=True)"
            )
        else:  # pragma: no cover
            return colors.warning("No fields are found!")
135
+
136
+
137
# Shared docstring fragments: injected into curator methods via `@doc_args`
# so that every curator subclass documents `validate()` / `save_artifact()`
# identically.
CAT_MANAGER_DOCSTRING = """Manage categoricals by updating registries."""


SLOTS_DOCSTRING = """Access sub curators by slot."""


VALIDATE_DOCSTRING = """Validate dataset against Schema.

Raises:
    lamindb.errors.ValidationError: If validation fails.
"""

SAVE_ARTIFACT_DOCSTRING = """Save an annotated artifact.

Args:
    key: A path-like key to reference artifact in default storage, e.g., `"myfolder/myfile.fcs"`. Artifacts with the same key form a version family.
    description: A description.
    revises: Previous version of the artifact. Is an alternative way to passing `key` to trigger a new version.
    run: The run that creates the artifact.

Returns:
    A saved artifact record.
"""
160
+
161
+
162
class Curator:
    """Curator base class.

    A `Curator` object makes it easy to validate, standardize & annotate datasets.

    See:
        - :class:`~lamindb.curators.DataFrameCurator`
        - :class:`~lamindb.curators.AnnDataCurator`
        - :class:`~lamindb.curators.MuDataCurator`
        - :class:`~lamindb.curators.SpatialDataCurator`
    """

    def __init__(self, dataset: Any, schema: Schema | None = None):
        """Store the dataset (loading it from an Artifact if needed) and schema.

        Args:
            dataset: The dataset or an Artifact wrapping it.
            schema: Optional Schema defining validation constraints.
        """
        self._artifact: Artifact | None = None  # pass the dataset as an artifact
        self._dataset: Any = dataset  # pass the dataset as a UPathStr or data object
        if isinstance(self._dataset, Artifact):
            self._artifact = self._dataset
            if self._artifact.otype in {
                "DataFrame",
                "AnnData",
                "MuData",
                "SpatialData",
            }:
                # load the in-memory object without registering a run input
                self._dataset = self._dataset.load(is_run_input=False)
        self._schema: Schema | None = schema
        self._is_validated: bool = False

    @doc_args(VALIDATE_DOCSTRING)
    def validate(self) -> bool | str:
        """{}"""  # noqa: D415
        pass  # pragma: no cover

    @doc_args(SAVE_ARTIFACT_DOCSTRING)
    def save_artifact(
        self,
        *,
        key: str | None = None,
        description: str | None = None,
        revises: Artifact | None = None,
        run: Run | None = None,
    ) -> Artifact:
        """{}"""  # noqa: D415
        # Note that this docstring has to be consistent with the Artifact()
        # constructor signature
        pass  # pragma: no cover

    def __repr__(self) -> str:
        from lamin_utils import colors

        if self._schema is not None:
            # Schema might have different attributes
            if hasattr(self._schema, "name") and self._schema.name:
                schema_str = colors.italic(self._schema.name)
            elif hasattr(self._schema, "uid"):
                schema_str = colors.italic(f"uid={self._schema.uid}")
            elif hasattr(self._schema, "id"):
                schema_str = colors.italic(f"id={self._schema.id}")
            else:
                schema_str = colors.italic("unnamed")

            # Add schema type info if available
            if hasattr(self._schema, "otype") and self._schema.otype:
                schema_str += f" ({self._schema.otype})"
        else:
            schema_str = colors.warning("None")

        status_str = ""
        if self._is_validated:
            status_str = f", {colors.green('validated')}"
        else:
            status_str = f", {colors.yellow('unvalidated')}"

        # fix: compare the raw class name, not the colored string —
        # `colors.green(...)` wraps the name in ANSI escape codes, so the
        # previous `cls_name == "DataFrameCurator"` comparison never matched
        raw_cls_name = self.__class__.__name__
        cls_name = colors.green(raw_cls_name)

        # Get additional info based on curator type
        extra_info = ""
        if hasattr(self, "_slots") and self._slots:
            # For SlotsCurator and its subclasses
            slots_count = len(self._slots)
            if slots_count > 0:
                slot_names = list(self._slots.keys())
                if len(slot_names) <= 3:
                    extra_info = f", slots: {slot_names}"
                else:
                    extra_info = f", slots: [{', '.join(slot_names[:3])}... +{len(slot_names) - 3} more]"
        elif (
            raw_cls_name == "DataFrameCurator"
            and hasattr(self, "cat")
            and hasattr(self.cat, "_categoricals")
        ):
            # For DataFrameCurator
            cat_count = len(getattr(self.cat, "_categoricals", []))
            if cat_count > 0:
                extra_info = f", categorical_features={cat_count}"

        artifact_info = ""
        if self._artifact is not None:
            artifact_uid = getattr(self._artifact, "uid", str(self._artifact))
            # shorten long uids for display
            short_uid = (
                str(artifact_uid)[:8] + "..."
                if len(str(artifact_uid)) > 8
                else str(artifact_uid)
            )
            artifact_info = f", artifact: {colors.italic(short_uid)}"

        return (
            f"{cls_name}{artifact_info}(Schema: {schema_str}{extra_info}{status_str})"
        )
270
+
271
+
272
# default implementation for AnnDataCurator, MuDataCurator, and SpatialDataCurator
class SlotsCurator(Curator):
    """Curator for a dataset with slots.

    Args:
        dataset: The dataset to validate & annotate.
        schema: A :class:`~lamindb.Schema` object that defines the validation constraints.

    """

    def __init__(
        self,
        dataset: Any,
        schema: Schema,
    ) -> None:
        super().__init__(dataset=dataset, schema=schema)
        # one sub-curator per slot, populated by subclasses
        self._slots: dict[str, DataFrameCurator] = {}

        # used in MuDataCurator and SpatialDataCurator
        # in form of {table/modality_key: var_field}
        self._var_fields: dict[str, FieldAttr] = {}
        # in form of {table/modality_key: categoricals}
        self._cat_vectors: dict[str, dict[str, CatVector]] = {}

    @property
    @doc_args(SLOTS_DOCSTRING)
    def slots(self) -> dict[str, DataFrameCurator]:
        """{}"""  # noqa: D415
        return self._slots

    @doc_args(VALIDATE_DOCSTRING)
    def validate(self) -> None:
        """{}"""  # noqa: D415
        for slot, curator in self._slots.items():
            logger.info(f"validating slot {slot} ...")
            curator.validate()

    @doc_args(SAVE_ARTIFACT_DOCSTRING)
    def save_artifact(
        self,
        *,
        key: str | None = None,
        description: str | None = None,
        revises: Artifact | None = None,
        run: Run | None = None,
    ) -> Artifact:
        """{}"""  # noqa: D415
        if not self._is_validated:
            self.validate()
        if self._artifact is None:
            # fix: made this a single elif chain — the first two checks were
            # independent `if`s, inconsistent with the final `elif` branch
            if data_is_anndata(self._dataset):
                self._artifact = Artifact.from_anndata(
                    self._dataset,
                    key=key,
                    description=description,
                    revises=revises,
                    run=run,
                )
            elif data_is_mudata(self._dataset):
                self._artifact = Artifact.from_mudata(
                    self._dataset,
                    key=key,
                    description=description,
                    revises=revises,
                    run=run,
                )
            elif data_is_spatialdata(self._dataset):
                self._artifact = Artifact.from_spatialdata(
                    self._dataset,
                    key=key,
                    description=description,
                    revises=revises,
                    run=run,
                )
            self._artifact.schema = self._schema
            self._artifact.save()
        # collect all categorical vectors across slot curators
        # fix: the loop variable previously shadowed the `key` parameter
        cat_vectors = {}
        for curator in self._slots.values():
            for vector_key, cat_vector in curator.cat._cat_vectors.items():
                cat_vectors[vector_key] = cat_vector
        return annotate_artifact(  # type: ignore
            self._artifact,
            curator=self,
            cat_vectors=cat_vectors,
        )
357
+
358
+
359
def check_dtype(expected_type) -> Callable:
    """Creates a check function for Pandera that validates a column's dtype.

    Args:
        expected_type: String identifier for the expected type ('int', 'float', or 'num')

    Returns:
        A function that checks if a series has the expected dtype
    """

    def check_function(series):
        if expected_type == "int":
            is_valid = pd.api.types.is_integer_dtype(series.dtype)
        elif expected_type == "float":
            is_valid = pd.api.types.is_float_dtype(series.dtype)
        elif expected_type == "num":
            is_valid = pd.api.types.is_numeric_dtype(series.dtype)
        else:
            # fix: any other identifier previously raised UnboundLocalError
            # because `is_valid` was never assigned
            is_valid = False
        return is_valid

    return check_function
379
+
380
+
381
# this is also currently used as DictCurator
class DataFrameCurator(Curator):
    # the example in the docstring is tested in test_curators_quickstart_example
    """Curator for `DataFrame`.

    Args:
        dataset: The DataFrame-like object to validate & annotate.
        schema: A :class:`~lamindb.Schema` object that defines the validation constraints.
        slot: Indicate the slot in a composite curator for a composite data structure.

    Example:

        For simple example using a flexible schema, see :meth:`~lamindb.Artifact.from_df`.

        Here is an example that enforces a minimal set of columns in the dataframe.

        .. literalinclude:: scripts/curate_dataframe_minimal_errors.py
            :language: python

        Under-the-hood, this used the following schema.

        .. literalinclude:: scripts/define_mini_immuno_schema_flexible.py
            :language: python

        Valid features & labels were defined as:

        .. literalinclude:: scripts/define_mini_immuno_features_labels.py
            :language: python
    """

    def __init__(
        self,
        dataset: pd.DataFrame | Artifact,
        schema: Schema,
        slot: str | None = None,
    ) -> None:
        super().__init__(dataset=dataset, schema=schema)
        categoricals = []
        features = []
        feature_ids: set[int] = set()
        if schema.flexible:
            # flexible schemas additionally validate any dataset column that
            # matches a registered feature name
            features += Feature.filter(name__in=self._dataset.keys()).list()
            feature_ids = {feature.id for feature in features}
        if schema.n > 0:
            if schema._index_feature_uid is not None:
                # the index feature is handled via `pandera.Index` below
                schema_features = [
                    feature
                    for feature in schema.members.list()
                    if feature.uid != schema._index_feature_uid  # type: ignore
                ]
            else:
                schema_features = schema.members.list()  # type: ignore
            if feature_ids:
                # avoid duplicating features already collected via the
                # flexible-schema path
                features.extend(
                    feature
                    for feature in schema_features
                    if feature.id not in feature_ids  # type: ignore
                )
            else:
                features.extend(schema_features)
        else:
            assert schema.itype is not None  # noqa: S101
        pandera_columns = {}
        if features or schema._index_feature_uid is not None:
            # populate features
            if schema.minimal_set:
                optional_feature_uids = set(schema.optionals.get_uids())
            for feature in features:
                if schema.minimal_set:
                    required = feature.uid not in optional_feature_uids
                else:
                    required = False
                if feature.dtype in {"int", "float", "num"}:
                    # numeric columns: validate via a custom dtype check so
                    # that "num" can accept both int and float columns
                    if isinstance(self._dataset, pd.DataFrame):
                        dtype = (
                            self._dataset[feature.name].dtype
                            if feature.name in self._dataset.keys()
                            else None
                        )
                    else:
                        dtype = None
                    pandera_columns[feature.name] = pandera.Column(
                        dtype=None,
                        checks=pandera.Check(
                            check_dtype(feature.dtype),
                            element_wise=False,
                            error=f"Column '{feature.name}' failed dtype check for '{feature.dtype}': got {dtype}",
                        ),
                        nullable=feature.nullable,
                        coerce=feature.coerce_dtype,
                        required=required,
                    )
                else:
                    pandera_dtype = (
                        feature.dtype
                        if not feature.dtype.startswith("cat")
                        else "category"
                    )
                    pandera_columns[feature.name] = pandera.Column(
                        pandera_dtype,
                        nullable=feature.nullable,
                        coerce=feature.coerce_dtype,
                        required=required,
                    )
                if feature.dtype.startswith("cat"):
                    # validate categoricals if the column is required or if the column is present
                    if required or feature.name in self._dataset.keys():
                        categoricals.append(feature)
            if schema._index_feature_uid is not None:
                # in almost no case, an index should have a pandas.CategoricalDtype in a DataFrame
                # so, we're typing it as `str` here
                index = pandera.Index(
                    schema.index.dtype
                    if not schema.index.dtype.startswith("cat")
                    else str
                )
            else:
                index = None
            self._pandera_schema = pandera.DataFrameSchema(
                pandera_columns,
                coerce=schema.coerce_dtype,
                strict=schema.maximal_set,
                ordered=schema.ordered_set,
                index=index,
            )
        self._cat_manager = DataFrameCatManager(
            self._dataset,
            columns_field=parse_cat_dtype(schema.itype, is_itype=True)["field"],
            columns_names=pandera_columns.keys(),
            categoricals=categoricals,
            index=schema.index,
            slot=slot,
            maximal_set=schema.maximal_set,
        )

    @property
    @doc_args(CAT_MANAGER_DOCSTRING)
    def cat(self) -> DataFrameCatManager:
        """{}"""  # noqa: D415
        return self._cat_manager

    def standardize(self) -> None:
        """Standardize the dataset.

        - Adds missing columns for features
        - Fills missing values for features with default values

        Raises:
            lamindb.errors.ValidationError: If a missing column is neither
                nullable nor has a default value.
        """
        for feature in self._schema.members:
            if feature.name not in self._dataset.columns:
                if feature.default_value is not None or feature.nullable:
                    fill_value = (
                        feature.default_value
                        if feature.default_value is not None
                        else pd.NA
                    )
                    if feature.dtype.startswith("cat"):
                        self._dataset[feature.name] = pd.Categorical(
                            [fill_value] * len(self._dataset)
                        )
                    else:
                        self._dataset[feature.name] = fill_value
                    logger.important(
                        f"added column {feature.name} with fill value {fill_value}"
                    )
                else:
                    raise ValidationError(
                        f"Missing column {feature.name} cannot be added because is not nullable and has no default value"
                    )
            else:
                if feature.default_value is not None:
                    if isinstance(
                        self._dataset[feature.name].dtype, pd.CategoricalDtype
                    ):
                        # make the default a known category before fillna
                        if (
                            feature.default_value
                            not in self._dataset[feature.name].cat.categories
                        ):
                            self._dataset[feature.name] = self._dataset[
                                feature.name
                            ].cat.add_categories(feature.default_value)
                    self._dataset[feature.name] = self._dataset[feature.name].fillna(
                        feature.default_value
                    )

    def _cat_manager_validate(self) -> None:
        # delegate categorical validation to the cat manager and surface its
        # error messages as a ValidationError
        self.cat.validate()

        if self.cat._is_validated:
            self._is_validated = True
        else:
            self._is_validated = False
            raise ValidationError(self.cat._validate_category_error_messages)

    @doc_args(VALIDATE_DOCSTRING)
    def validate(self) -> None:
        """{}"""  # noqa: D415
        if self._schema.n > 0:
            try:
                # first validate through pandera
                self._pandera_schema.validate(self._dataset)
                # then validate lamindb categoricals
                self._cat_manager_validate()
            except pandera.errors.SchemaError as err:
                self._is_validated = False
                # .exconly() doesn't exist on SchemaError
                raise ValidationError(str(err)) from err
        else:
            self._cat_manager_validate()

    @doc_args(SAVE_ARTIFACT_DOCSTRING)
    def save_artifact(
        self,
        *,
        key: str | None = None,
        description: str | None = None,
        revises: Artifact | None = None,
        run: Run | None = None,
    ) -> Artifact:
        """{}"""  # noqa: D415
        if not self._is_validated:
            self.validate()  # raises ValidationError if doesn't validate
        if self._artifact is None:
            self._artifact = Artifact.from_df(
                self._dataset,
                key=key,
                description=description,
                revises=revises,
                run=run,
                # fix: `key` defaults to None, in which case `key.endswith`
                # raised AttributeError
                format=".csv" if key is not None and key.endswith(".csv") else None,
            )
            self._artifact.schema = self._schema
            self._artifact.save()
        return annotate_artifact(  # type: ignore
            self._artifact,
            cat_vectors=self.cat._cat_vectors,
        )
617
+
618
+
619
class AnnDataCurator(SlotsCurator):
    """Curator for `AnnData`.

    Args:
        dataset: The AnnData-like object to validate & annotate.
        schema: A :class:`~lamindb.Schema` object that defines the validation constraints.

    Example:

        See :meth:`~lamindb.Artifact.from_anndata`.

    """

    def __init__(
        self,
        dataset: AnnData | Artifact,
        schema: Schema,
    ) -> None:
        super().__init__(dataset=dataset, schema=schema)
        if not data_is_anndata(self._dataset):
            raise InvalidArgument("dataset must be AnnData-like.")
        if schema.otype != "AnnData":
            raise InvalidArgument("Schema otype must be 'AnnData'.")
        self._slots = {
            slot: DataFrameCurator(
                (
                    # fix: use `removesuffix(".T")` instead of `strip(".T")` —
                    # `str.strip` removes any of the characters "." and "T"
                    # from both ends and only worked here by coincidence
                    getattr(self._dataset, slot.removesuffix(".T")).T
                    if slot == "var.T"
                    or (
                        # backward compat
                        slot == "var"
                        and schema.slots["var"].itype not in {None, "Feature"}
                    )
                    else getattr(self._dataset, slot)
                ),
                slot_schema,
                slot=slot,
            )
            for slot, slot_schema in schema.slots.items()
            if slot in {"obs", "var", "var.T", "uns"}
        }
        if "var" in self._slots and schema.slots["var"].itype not in {None, "Feature"}:
            logger.warning(
                "auto-transposed `var` for backward compat, please indicate transposition in the schema definition by calling out `.T`: slots={'var.T': itype=bt.Gene.ensembl_gene_id}"
            )
            # after transposition, the var DataFrame's columns are the
            # AnnData's var index — rename the vector key accordingly
            self._slots["var"].cat._cat_vectors["var_index"] = self._slots[
                "var"
            ].cat._cat_vectors.pop("columns")
            self._slots["var"].cat._cat_vectors["var_index"]._key = "var_index"
668
+
669
+
670
+ def _assign_var_fields_categoricals_multimodal(
671
+ modality: str | None,
672
+ slot_type: str,
673
+ slot: str,
674
+ slot_schema: Schema,
675
+ var_fields: dict[str, FieldAttr],
676
+ cat_vectors: dict[str, dict[str, CatVector]],
677
+ slots: dict[str, DataFrameCurator],
678
+ ) -> None:
679
+ """Assigns var_fields and categoricals for multimodal data curators."""
680
+ if modality is not None:
681
+ # Makes sure that all tables are present
682
+ var_fields[modality] = None
683
+ cat_vectors[modality] = {}
684
+
685
+ if slot_type == "var":
686
+ var_field = parse_cat_dtype(slot_schema.itype, is_itype=True)["field"]
687
+ if modality is None:
688
+ # This should rarely/never be used since tables should have different var fields
689
+ var_fields[slot] = var_field # pragma: no cover
690
+ else:
691
+ # Note that this is NOT nested since the nested key is always "var"
692
+ var_fields[modality] = var_field
693
+ else:
694
+ obs_fields = slots[slot].cat._cat_vectors
695
+ if modality is None:
696
+ cat_vectors[slot] = obs_fields
697
+ else:
698
+ # Note that this is NOT nested since the nested key is always "obs"
699
+ cat_vectors[modality] = obs_fields
700
+
701
+
702
class MuDataCurator(SlotsCurator):
    """Curator for `MuData`.

    Args:
        dataset: The MuData-like object to validate & annotate.
        schema: A :class:`~lamindb.Schema` object that defines the validation constraints.

    Example:

        .. literalinclude:: scripts/curate-mudata.py
            :language: python
            :caption: curate-mudata.py
    """

    def __init__(
        self,
        dataset: MuData | Artifact,
        schema: Schema,
    ) -> None:
        super().__init__(dataset=dataset, schema=schema)
        if not data_is_mudata(self._dataset):
            raise InvalidArgument("dataset must be MuData-like.")
        if schema.otype != "MuData":
            raise InvalidArgument("Schema otype must be 'MuData'.")

        for slot, slot_schema in schema.slots.items():
            # a slot of the form "<modality>:<slot>" addresses one modality;
            # otherwise the slot lives on the MuData object itself
            if ":" in slot:
                modality, modality_slot = slot.split(":")
                schema_dataset = self._dataset[modality]
            else:
                modality, modality_slot = None, slot
                schema_dataset = self._dataset
            # legacy schemas encode a transposed var without the ".T" suffix
            is_legacy_var = modality_slot == "var" and schema.slots[
                slot
            ].itype not in {None, "Feature"}
            if is_legacy_var:
                logger.warning(
                    "auto-transposed `var` for backward compat, please indicate transposition in the schema definition by calling out `.T`: slots={'var.T': itype=bt.Gene.ensembl_gene_id}"
                )
            if modality_slot == "var.T" or is_legacy_var:
                slot_object = getattr(schema_dataset, modality_slot.rstrip(".T")).T
            else:
                slot_object = getattr(schema_dataset, modality_slot)
            self._slots[slot] = DataFrameCurator(slot_object, slot_schema)
            _assign_var_fields_categoricals_multimodal(
                modality=modality,
                slot_type=modality_slot,
                slot=slot,
                slot_schema=slot_schema,
                var_fields=self._var_fields,
                cat_vectors=self._cat_vectors,
                slots=self._slots,
            )
        self._columns_field = self._var_fields
764
+
765
+
766
class SpatialDataCurator(SlotsCurator):
    """Curator for `SpatialData`.

    Args:
        dataset: The SpatialData-like object to validate & annotate.
        schema: A :class:`~lamindb.Schema` object that defines the validation constraints.

    Example:

        See :meth:`~lamindb.Artifact.from_spatialdata`.
    """

    def __init__(
        self,
        dataset: SpatialData | Artifact,
        schema: Schema,
        *,
        # NOTE(review): sample_metadata_key is not referenced in this __init__ —
        # presumably kept for API compatibility; confirm against callers
        sample_metadata_key: str | None = "sample",
    ) -> None:
        super().__init__(dataset=dataset, schema=schema)
        if not data_is_spatialdata(self._dataset):
            raise InvalidArgument("dataset must be SpatialData-like.")
        if schema.otype != "SpatialData":
            raise InvalidArgument("Schema otype must be 'SpatialData'.")

        for slot, slot_schema in schema.slots.items():
            split_result = slot.split(":")
            # "tables:<table_key>:<sub_slot>" addresses a table's obs/var;
            # "table:<sub_slot>" is the deprecated two-part spelling
            if (len(split_result) == 2 and split_result[0] == "table") or (
                len(split_result) == 3 and split_result[0] == "tables"
            ):
                if len(split_result) == 2:
                    # NOTE(review): for the deprecated form, table_key becomes the
                    # literal string "table" — confirm this matches legacy naming
                    table_key, sub_slot = split_result
                    logger.warning(
                        f"please prefix slot {slot} with 'tables:' going forward"
                    )
                else:
                    table_key, sub_slot = split_result[1], split_result[2]
                slot_object = self._dataset.tables.__getitem__(table_key)
                if sub_slot == "var" and schema.slots[slot].itype not in {
                    None,
                    "Feature",
                }:
                    logger.warning(
                        "auto-transposed `var` for backward compat, please indicate transposition in the schema definition by calling out `.T`: slots={'var.T': itype=bt.Gene.ensembl_gene_id}"
                    )
                data_object = (
                    # fix: `removesuffix(".T")` instead of `rstrip(".T")`, which
                    # strips any of the characters "." and "T" from the end
                    getattr(slot_object, sub_slot.removesuffix(".T")).T
                    if sub_slot == "var.T"
                    or (
                        # backward compat
                        sub_slot == "var"
                        and schema.slots[slot].itype not in {None, "Feature"}
                    )
                    else getattr(slot_object, sub_slot)
                )
            elif len(split_result) == 1 or (
                len(split_result) > 1 and split_result[0] == "attrs"
            ):
                table_key = None
                if len(split_result) == 1:
                    if split_result[0] != "attrs":
                        logger.warning(
                            f"please prefix slot {slot} with 'attrs:' going forward"
                        )
                        sub_slot = slot
                        data_object = self._dataset.attrs[slot]
                    else:
                        sub_slot = "attrs"
                        data_object = self._dataset.attrs
                elif len(split_result) == 2:
                    sub_slot = split_result[1]
                    data_object = self._dataset.attrs[split_result[1]]
                else:
                    # fix: deeper "attrs:a:b" slots previously fell through and
                    # later crashed with an opaque NameError on `data_object`
                    raise InvalidArgument(f"invalid slot name: {slot!r}")
                # attrs entries are plain mappings; wrap as a one-row DataFrame
                data_object = pd.DataFrame([data_object])
            else:
                # fix: unrecognized slot names previously fell through and
                # later crashed with an opaque NameError on `data_object`
                raise InvalidArgument(f"invalid slot name: {slot!r}")
            self._slots[slot] = DataFrameCurator(data_object, slot_schema, slot)
            _assign_var_fields_categoricals_multimodal(
                modality=table_key,
                slot_type=sub_slot,
                slot=slot,
                slot_schema=slot_schema,
                var_fields=self._var_fields,
                cat_vectors=self._cat_vectors,
                slots=self._slots,
            )
        self._columns_field = self._var_fields
850
+
851
+
852
class CatVector:
    """Vector with categorical values.

    Holds one vector of values (accessed through a getter/setter pair) together
    with the registry field it is validated against, and tracks validation
    state: validated values, non-validated values, and a synonym mapping.
    """

    def __init__(
        self,
        values_getter: Callable
        | Iterable[str],  # A callable or iterable that returns the values to validate.
        field: FieldAttr,  # The field to validate against.
        key: str,  # The name of the vector to validate. Only used for logging.
        values_setter: Callable | None = None,  # A callable that sets the values.
        source: Record | None = None,  # The ontology source to validate against.
        feature: Feature | None = None,
        cat_manager: DataFrameCatManager | None = None,
        subtype_str: str = "",
        maximal_set: bool = True,  # whether unvalidated categoricals cause validation failure.
    ) -> None:
        self._values_getter = values_getter
        self._values_setter = values_setter
        self._field = field
        self._key = key
        self._source = source
        self._organism = None
        # None means validate() has not been run yet
        self._validated: None | list[str] = None
        self._non_validated: None | list[str] = None
        self._synonyms: None | dict[str, str] = None
        self._subtype_str = subtype_str
        self._subtype_query_set = None
        self._cat_manager = cat_manager
        self.feature = feature
        self.records = None
        self._maximal_set = maximal_set
        # whether the validated field is the registry's name field
        if hasattr(field.field.model, "_name_field"):
            label_ref_is_name = field.field.name == field.field.model._name_field
        else:
            label_ref_is_name = field.field.name == "name"
        self.label_ref_is_name = label_ref_is_name

    @property
    def values(self):
        """Get the current values using the getter function."""
        if callable(self._values_getter):
            return self._values_getter()
        return self._values_getter

    @values.setter
    def values(self, new_values):
        """Set new values using the setter function if available."""
        if callable(self._values_setter):
            self._values_setter(new_values)
        else:
            # If values_getter is not callable, it's a direct reference we can update
            self._values_getter = new_values

    @property
    def is_validated(self) -> bool:
        """Whether the vector is validated."""
        # if nothing was validated, something likely is fundamentally wrong
        # should probably add a setting `at_least_one_validated`
        result = True
        if len(self.values) > 0 and len(self.values) == len(self._non_validated):
            result = False
        # len(self._non_validated) != 0
        #   if maximal_set is True, return False
        #   if maximal_set is False, return True
        # len(self._non_validated) == 0
        #   return True
        if len(self._non_validated) != 0:
            if self._maximal_set:
                result = False
        return result

    def _replace_synonyms(self) -> list[str]:
        """Replace synonyms in the vector with standardized values."""
        syn_mapper = self._synonyms
        # replace the values in df
        std_values = self.values.map(
            lambda unstd_val: syn_mapper.get(unstd_val, unstd_val)
        )
        # remove the standardized values from self.non_validated
        non_validated = [i for i in self._non_validated if i not in syn_mapper]
        if len(non_validated) == 0:
            self._non_validated = []
        else:
            self._non_validated = non_validated  # type: ignore
        # logging
        n = len(syn_mapper)
        if n > 0:
            syn_mapper_print = _format_values(
                [f'"{k}" → "{v}"' for k, v in syn_mapper.items()], sep=""
            )
            s = "s" if n > 1 else ""
            logger.success(
                f'standardized {n} synonym{s} in "{self._key}": {colors.green(syn_mapper_print)}'
            )
        return std_values

    def __repr__(self) -> str:
        """Short summary including key, field, value count, and validation status."""
        if self._non_validated is None:
            status = "unvalidated"
        else:
            status = (
                "validated"
                if len(self._non_validated) == 0
                else f"non-validated ({len(self._non_validated)})"
            )

        field_name = getattr(self._field, "name", str(self._field))
        values_count = len(self.values) if hasattr(self.values, "__len__") else "?"
        return f"CatVector(key='{self._key}', field='{field_name}', values={values_count}, {status})"

    def _add_validated(self) -> tuple[list, list]:
        """Save features or labels records in the default instance.

        Returns a tuple of (validated labels, non-validated labels).
        """
        from lamindb.models.save import save as ln_save

        registry = self._field.field.model
        field_name = self._field.field.name
        model_field = registry.__get_name_with_module__()
        filter_kwargs = get_current_filter_kwargs(
            registry, {"organism": self._organism, "source": self._source}
        )
        # only validate non-empty strings
        values = [i for i in self.values if isinstance(i, str) and i]
        if not values:
            return [], []
        # inspect the default instance and save validated records from public
        if (
            self._subtype_str != "" and "__" not in self._subtype_str
        ):  # not for general filter expressions
            self._subtype_query_set = registry.get(name=self._subtype_str).records.all()
            values_array = np.array(values)
            validated_mask = self._subtype_query_set.validate(  # type: ignore
                values_array, field=self._field, **filter_kwargs, mute=True
            )
            validated_labels, non_validated_labels = (
                values_array[validated_mask],
                values_array[~validated_mask],
            )
            records = registry.from_values(
                validated_labels, field=self._field, **filter_kwargs, mute=True
            )
        else:
            existing_and_public_records = registry.from_values(
                list(values), field=self._field, **filter_kwargs, mute=True
            )
            existing_and_public_labels = [
                getattr(r, field_name) for r in existing_and_public_records
            ]
            # public records that are not already in the database
            public_records = [r for r in existing_and_public_records if r._state.adding]
            # here we check to only save the public records if they are from the specified source
            # we check the uid because r.source and source can be from different instances
            if self._source:
                public_records = [
                    r for r in public_records if r.source.uid == self._source.uid
                ]
            if len(public_records) > 0:
                logger.info(f"saving validated records of '{self._key}'")
                ln_save(public_records)
                labels_saved_public = [getattr(r, field_name) for r in public_records]
                # log the saved public labels
                # the term "transferred" stresses that this is always in the context of transferring
                # labels from a public ontology or a different instance to the present instance
                if len(labels_saved_public) > 0:
                    s = "s" if len(labels_saved_public) > 1 else ""
                    logger.success(
                        f'added {len(labels_saved_public)} record{s} {colors.green("from_public")} with {model_field} for "{self._key}": {_format_values(labels_saved_public)}'
                    )
            # non-validated records from the default instance
            non_validated_labels = [
                i for i in values if i not in existing_and_public_labels
            ]
            validated_labels = existing_and_public_labels
            records = existing_and_public_records

        self.records = records
        # validated, non-validated
        return validated_labels, non_validated_labels

    def _add_new(
        self,
        values: list[str],
        df: pd.DataFrame | None = None,  # remove when all users use schema
        dtype: str | None = None,
        **create_kwargs,
    ) -> None:
        """Add new labels to the registry."""
        from lamindb.models.save import save as ln_save

        registry = self._field.field.model
        field_name = self._field.field.name
        non_validated_records: RecordList[Any] = []  # type: ignore
        if df is not None and registry == Feature:
            nonval_columns = Feature.inspect(df.columns, mute=True).non_validated
            non_validated_records = Feature.from_df(df.loc[:, nonval_columns])
        else:
            if (
                self._organism
                and hasattr(registry, "organism")
                and registry._meta.get_field("organism").is_relation
            ):
                # make sure organism record is saved to the current instance
                create_kwargs["organism"] = _save_organism(name=self._organism)

            for value in values:
                init_kwargs = {field_name: value}
                if registry == Feature:
                    init_kwargs["dtype"] = "cat" if dtype is None else dtype
                non_validated_records.append(registry(**init_kwargs, **create_kwargs))
        if len(non_validated_records) > 0:
            ln_save(non_validated_records)
            model_field = colors.italic(registry.__get_name_with_module__())
            s = "s" if len(values) > 1 else ""
            logger.success(
                f'added {len(values)} record{s} with {model_field} for "{self._key}": {_format_values(values)}'
            )

    def _validate(
        self,
        values: list[str],
    ) -> tuple[list[str], dict]:
        """Validate ontology terms using LaminDB registries.

        Returns a tuple of (non-validated values, synonym mapper).
        """
        registry = self._field.field.model
        field_name = self._field.field.name
        model_field = f"{registry.__name__}.{field_name}"

        kwargs_current = get_current_filter_kwargs(
            registry, {"organism": self._organism, "source": self._source}
        )

        # inspect values from the default instance, excluding public
        registry_or_queryset = registry
        if self._subtype_query_set is not None:
            registry_or_queryset = self._subtype_query_set
        inspect_result = registry_or_queryset.inspect(
            values, field=self._field, mute=True, from_source=False, **kwargs_current
        )
        non_validated = inspect_result.non_validated
        syn_mapper = inspect_result.synonyms_mapper

        # inspect the non-validated values from public (BioRecord only)
        values_validated = []
        if hasattr(registry, "public"):
            public_records = registry.from_values(
                non_validated,
                field=self._field,
                mute=True,
                **kwargs_current,
            )
            values_validated += [getattr(r, field_name) for r in public_records]

        # logging messages
        if self._cat_manager is not None:
            slot = self._cat_manager._slot
        else:
            slot = None
        in_slot = f" in slot '{slot}'" if slot is not None else ""
        slot_prefix = f".slots['{slot}']" if slot is not None else ""
        non_validated_hint_print = (
            f"curator{slot_prefix}.cat.add_new_from('{self._key}')"
        )
        non_validated = [i for i in non_validated if i not in values_validated]
        n_non_validated = len(non_validated)
        if n_non_validated == 0:
            logger.success(
                f'"{self._key}" is validated against {colors.italic(model_field)}'
            )
            return [], {}
        else:
            s = "" if n_non_validated == 1 else "s"
            print_values = _format_values(non_validated)
            warning_message = f"{colors.red(f'{n_non_validated} term{s}')} not validated in feature '{self._key}'{in_slot}: {colors.red(print_values)}\n"
            if syn_mapper:
                s = "" if len(syn_mapper) == 1 else "s"
                syn_mapper_print = _format_values(
                    [f'"{k}" → "{v}"' for k, v in syn_mapper.items()], sep=""
                )
                hint_msg = f'.standardize("{self._key}")'
                warning_message += f"  {colors.yellow(f'{len(syn_mapper)} synonym{s}')} found: {colors.yellow(syn_mapper_print)}\n    → curate synonyms via: {colors.cyan(hint_msg)}"
            if n_non_validated > len(syn_mapper):
                if syn_mapper:
                    warning_message += "\n    for remaining terms:\n"
                warning_message += f"  → fix typos, remove non-existent values, or save terms via: {colors.cyan(non_validated_hint_print)}"
                if self._subtype_query_set is not None:
                    warning_message += f"\n  → a valid label for subtype '{self._subtype_str}' has to be one of {self._subtype_query_set.list('name')}"
            logger.info(f'mapping "{self._key}" on {colors.italic(model_field)}')
            logger.warning(warning_message)
            if self._cat_manager is not None:
                self._cat_manager._validate_category_error_messages = strip_ansi_codes(
                    warning_message
                )
            return non_validated, syn_mapper

    def validate(self) -> None:
        """Validate the vector."""
        # add source-validated values to the registry
        self._validated, self._non_validated = self._add_validated()
        self._non_validated, self._synonyms = self._validate(values=self._non_validated)

        # always register new Features if they are columns
        if self._key == "columns" and self._field == Feature.name:
            self.add_new()

    def standardize(self) -> None:
        """Standardize the vector."""
        registry = self._field.field.model
        # registries without a standardize API cannot be standardized
        if not hasattr(registry, "standardize"):
            return self.values
        if self._synonyms is None:
            self.validate()
        # get standardized values
        std_values = self._replace_synonyms()
        # update non_validated values
        self._non_validated = [
            i for i in self._non_validated if i not in self._synonyms.keys()
        ]
        # remove synonyms since they are now standardized
        self._synonyms = {}
        # update the values with the standardized values
        self.values = std_values

    def add_new(self, **create_kwargs) -> None:
        """Add new values to the registry."""
        if self._non_validated is None:
            self.validate()
        if len(self._synonyms) > 0:
            # raise error because .standardize modifies the input dataset
            raise ValidationError(
                "Please run `.standardize()` before adding new values."
            )
        self._add_new(
            values=self._non_validated,
            **create_kwargs,
        )
        # remove the non_validated values since they are now registered
        self._non_validated = []
1186
+
1187
+
1188
class DataFrameCatManager:
    """Manage categoricals by updating registries.

    This class is accessible from within a `DataFrameCurator` via the `.cat` attribute.

    If you find non-validated values, you have two options:

    - new values found in the data can be registered via `DataFrameCurator.cat.add_new_from()` :meth:`~lamindb.curators.core.DataFrameCatManager.add_new_from`
    - non-validated values can be accessed via `DataFrameCurator.cat.non_validated` :meth:`~lamindb.curators.core.DataFrameCatManager.non_validated` and addressed manually
    """

    def __init__(
        self,
        df: pd.DataFrame | Artifact,
        columns_field: FieldAttr = Feature.name,
        columns_names: Iterable[str] | None = None,
        categoricals: list[Feature] | None = None,
        sources: dict[str, Record] | None = None,
        index: Feature | None = None,
        slot: str | None = None,
        maximal_set: bool = False,
    ) -> None:
        self._non_validated = None
        self._index = index
        self._artifact: Artifact = None  # pass the dataset as an artifact
        self._dataset: Any = df  # pass the dataset as a UPathStr or data object
        if isinstance(self._dataset, Artifact):
            self._artifact = self._dataset
            self._dataset = self._dataset.load(is_run_input=False)
        self._is_validated: bool = False
        self._categoricals = categoricals or []
        self._non_validated = None
        self._sources = sources or {}
        self._columns_field = columns_field
        self._validate_category_error_messages: str = ""
        self._cat_vectors: dict[str, CatVector] = {}
        self._slot = slot
        self._maximal_set = maximal_set

        if columns_names is None:
            columns_names = []
        # one CatVector per column plus one for the column names themselves
        if columns_field == Feature.name:
            self._cat_vectors["columns"] = CatVector(
                values_getter=columns_names,
                field=columns_field,
                key="columns" if isinstance(self._dataset, pd.DataFrame) else "keys",
                source=self._sources.get("columns"),
                cat_manager=self,
                maximal_set=self._maximal_set,
            )
        else:
            self._cat_vectors["columns"] = CatVector(
                values_getter=lambda: self._dataset.columns,  # lambda ensures the inplace update
                values_setter=lambda new_values: setattr(
                    self._dataset, "columns", pd.Index(new_values)
                ),
                field=columns_field,
                key="columns",
                source=self._sources.get("columns"),
                cat_manager=self,
                maximal_set=self._maximal_set,
            )
        for feature in self._categoricals:
            result = parse_dtype(feature.dtype)[
                0
            ]  # TODO: support composite dtypes for categoricals
            key = feature.name
            field = result["field"]
            subtype_str = result["subtype_str"]
            self._cat_vectors[key] = CatVector(
                values_getter=lambda k=key: self._dataset[
                    k
                ],  # Capture key as default argument
                values_setter=lambda new_values, k=key: self._dataset.__setitem__(
                    k, new_values
                ),
                field=field,
                key=key,
                source=self._sources.get(key),
                feature=feature,
                cat_manager=self,
                subtype_str=subtype_str,
            )
        # a categorical index gets its own vector under the reserved key "index"
        if index is not None and index.dtype.startswith("cat"):
            result = parse_dtype(index.dtype)[0]
            field = result["field"]
            key = "index"
            self._cat_vectors[key] = CatVector(
                values_getter=self._dataset.index,
                field=field,
                key=key,
                feature=index,
                cat_manager=self,
            )

    @property
    def non_validated(self) -> dict[str, list[str]]:
        """Return the non-validated features and labels."""
        if self._non_validated is None:
            raise ValidationError("Please run validate() first!")
        return {
            key: cat_vector._non_validated
            for key, cat_vector in self._cat_vectors.items()
            if cat_vector._non_validated and key != "columns"
        }

    @property
    def categoricals(self) -> list[Feature]:
        """The categorical features."""
        return self._categoricals

    def lookup(self, public: bool = False) -> CatLookup:
        """Lookup categories.

        Args:
            public: If "public", the lookup is performed on the public reference.
        """
        return CatLookup(
            categoricals=self._categoricals,
            slots={"columns": self._columns_field},
            public=public,
            sources=self._sources,
        )

    def validate(self) -> bool:
        """Validate variables and categorical observations."""
        self._validate_category_error_messages = ""  # reset the error messages

        validated = True
        for key, cat_vector in self._cat_vectors.items():
            logger.info(f"validating vector {key}")
            cat_vector.validate()
            validated &= cat_vector.is_validated
        self._is_validated = validated
        self._non_validated = {}  # type: ignore

        if self._index is not None:
            # cat_vector.validate() populates validated labels
            # the index should become part of the feature set corresponding to the dataframe
            if self._cat_vectors["columns"].records is not None:
                self._cat_vectors["columns"].records.insert(0, self._index)  # type: ignore
            else:
                self._cat_vectors["columns"].records = [self._index]  # type: ignore

        return self._is_validated

    def standardize(self, key: str) -> None:
        """Replace synonyms with standardized values.

        Modifies the input dataset inplace.

        Args:
            key: The key referencing the column in the DataFrame to standardize.
        """
        if self._artifact is not None:
            raise RuntimeError("can't mutate the dataset when an artifact is passed!")

        if key == "all":
            logger.warning(
                "'all' is deprecated, please pass a single key from `.non_validated.keys()` instead!"
            )
            for k in self.non_validated.keys():
                self._cat_vectors[k].standardize()
        else:
            self._cat_vectors[key].standardize()

    def add_new_from(self, key: str, **kwargs):
        """Add validated & new categories.

        Args:
            key: The key referencing the slot in the DataFrame from which to draw terms.
            **kwargs: Additional keyword arguments to pass to create new records
        """
        if len(kwargs) > 0 and key == "all":
            raise ValueError("Cannot pass additional arguments to 'all' key!")
        if key == "all":
            logger.warning(
                "'all' is deprecated, please pass a single key from `.non_validated.keys()` instead!"
            )
            for k in self.non_validated.keys():
                self._cat_vectors[k].add_new(**kwargs)
        else:
            self._cat_vectors[key].add_new(**kwargs)
1371
+
1372
+
1373
def get_current_filter_kwargs(registry: type[Record], kwargs: dict) -> dict:
    """Make sure the source and organism are saved in the same database as the registry."""
    target_db = registry.filter().db
    filter_kwargs = kwargs.copy()
    # records living in a non-default database must be copied into the
    # default database before they can be used as filters there
    transfer_needed = target_db is None or target_db == "default"

    for kwarg_name in ("organism", "source"):
        record = kwargs.get(kwarg_name)
        if not isinstance(record, Record) or record._state.db == "default":
            continue
        if transfer_needed:
            # save a shallow copy of the record in the default database
            local_record = copy.copy(record)
            local_record.save()
            filter_kwargs[kwarg_name] = local_record

    return filter_kwargs
1394
+
1395
+
1396
def get_organism_kwargs(
    field: FieldAttr, organism: str | None = None, values: Any = None
) -> dict[str, str]:
    """Check if a registry needs an organism and return the organism name.

    Returns ``{"organism": <name>}`` when the registry is a BioRecord that
    requires an organism and one can be resolved (from the argument, the
    bionty settings, or inferred from the field/values); otherwise ``{}``.
    """
    registry = field.field.model
    if registry.__base__.__name__ == "BioRecord":
        import bionty as bt
        from bionty._organism import is_organism_required

        from ..models._from_values import get_organism_record_from_field

        if is_organism_required(registry):
            # explicit argument or global bionty setting takes precedence
            if organism is not None or bt.settings.organism is not None:
                return {"organism": organism or bt.settings.organism.name}
            else:
                # fall back to inferring the organism from the field and values
                organism_record = get_organism_record_from_field(
                    field, organism=organism, values=values
                )
                if organism_record is not None:
                    return {"organism": organism_record.name}
    return {}
1417
+
1418
+
1419
def annotate_artifact(
    artifact: Artifact,
    *,
    curator: AnnDataCurator | SlotsCurator | None = None,
    cat_vectors: dict[str, CatVector] | None = None,
) -> Artifact:
    """Annotate an artifact with validated labels and inferred feature sets.

    Labels come from the validated records of each ``CatVector``; feature sets
    (schemas) are attached per slot. Annotation is skipped when the number of
    records exceeds ``settings.annotation.n_max_records``.
    """
    from .. import settings
    from ..models.artifact import add_labels

    if cat_vectors is None:
        cat_vectors = {}

    # annotate with labels
    for key, cat_vector in cat_vectors.items():
        # feature-typed vectors and the column/var-index vectors are not labels
        if (
            cat_vector._field.field.model == Feature
            or key == "columns"
            or key == "var_index"
        ):
            continue
        if len(cat_vector.records) > settings.annotation.n_max_records:
            logger.important(
                f"not annotating with {len(cat_vector.records)} labels for feature {key} as it exceeds {settings.annotation.n_max_records} (ln.settings.annotation.n_max_records)"
            )
            continue
        add_labels(
            artifact,
            records=cat_vector.records,
            feature=cat_vector.feature,
            feature_ref_is_name=None,  # do not need anymore
            label_ref_is_name=cat_vector.label_ref_is_name,
            from_curator=True,
        )

    # annotate with inferred schemas aka feature sets
    if artifact.otype == "DataFrame":
        features = cat_vectors["columns"].records
        if features is not None:
            feature_set = Schema(
                features=features, coerce_dtype=artifact.schema.coerce_dtype
            )  # TODO: add more defaults from validating schema
            if (
                feature_set._state.adding
                and len(features) > settings.annotation.n_max_records
            ):
                logger.important(
                    f"not annotating with {len(features)} features as it exceeds {settings.annotation.n_max_records} (ln.settings.annotation.n_max_records)"
                )
                # fall back to an itype-only schema that records just the count
                itype = parse_cat_dtype(artifact.schema.itype, is_itype=True)["field"]
                feature_set = Schema(itype=itype, n=len(features))
            artifact.feature_sets.add(
                feature_set.save(), through_defaults={"slot": "columns"}
            )
    else:
        for slot, slot_curator in curator._slots.items():
            # var_index is backward compat (2025-05-01)
            name = (
                "var_index"
                if (slot == "var" and "var_index" in slot_curator.cat._cat_vectors)
                else "columns"
            )
            features = slot_curator.cat._cat_vectors[name].records
            if features is None:
                logger.warning(f"no features found for slot {slot}")
                continue
            itype = parse_cat_dtype(artifact.schema.slots[slot].itype, is_itype=True)[
                "field"
            ]
            feature_set = Schema(features=features, itype=itype)
            if (
                feature_set._state.adding
                and len(features) > settings.annotation.n_max_records
            ):
                logger.important(
                    f"not annotating with {len(features)} features for slot {slot} as it exceeds {settings.annotation.n_max_records} (ln.settings.annotation.n_max_records)"
                )
                # fall back to an itype-only schema that records just the count
                feature_set = Schema(itype=itype, n=len(features))
            artifact.feature_sets.add(
                feature_set.save(), through_defaults={"slot": slot}
            )

    slug = ln_setup.settings.instance.slug
    if ln_setup.settings.instance.is_remote:  # pragma: no cover
        logger.important(f"go to https://lamin.ai/{slug}/artifact/{artifact.uid}")
    return artifact
1504
+
1505
+
1506
+ # TODO: need this function to support mutli-value columns
1507
+ def _flatten_unique(series: pd.Series[list[Any] | Any]) -> list[Any]:
1508
+ """Flatten a Pandas series containing lists or single items into a unique list of elements."""
1509
+ result = set()
1510
+
1511
+ for item in series:
1512
+ if isinstance(item, list):
1513
+ result.update(item)
1514
+ else:
1515
+ result.add(item)
1516
+
1517
+ return list(result)
1518
+
1519
+
1520
def _save_organism(name: str):
    """Fetch the organism record by name, creating it from the public source if needed."""
    import bionty as bt

    # reuse an existing record if the instance already has one
    record = bt.Organism.filter(name=name).one_or_none()
    if record is not None:
        return record

    record = bt.Organism.from_source(name=name)
    if record is None:
        raise ValidationError(
            f'Organism "{name}" not found from public reference\n'
            f'      → please save it from a different source: bt.Organism.from_source(name="{name}", source).save()'
            f'      → or manually save it without source: bt.Organism(name="{name}").save()'
        )
    record.save()
    return record