lamindb 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. lamindb/__init__.py +52 -36
  2. lamindb/_finish.py +17 -10
  3. lamindb/_tracked.py +1 -1
  4. lamindb/base/__init__.py +3 -1
  5. lamindb/base/fields.py +40 -22
  6. lamindb/base/ids.py +1 -94
  7. lamindb/base/types.py +2 -0
  8. lamindb/base/uids.py +117 -0
  9. lamindb/core/_context.py +216 -133
  10. lamindb/core/_settings.py +38 -25
  11. lamindb/core/datasets/__init__.py +11 -4
  12. lamindb/core/datasets/_core.py +5 -5
  13. lamindb/core/datasets/_small.py +0 -93
  14. lamindb/core/datasets/mini_immuno.py +172 -0
  15. lamindb/core/loaders.py +1 -1
  16. lamindb/core/storage/_backed_access.py +100 -6
  17. lamindb/core/storage/_polars_lazy_df.py +51 -0
  18. lamindb/core/storage/_pyarrow_dataset.py +15 -30
  19. lamindb/core/storage/objects.py +6 -0
  20. lamindb/core/subsettings/__init__.py +2 -0
  21. lamindb/core/subsettings/_annotation_settings.py +11 -0
  22. lamindb/curators/__init__.py +7 -3559
  23. lamindb/curators/_legacy.py +2056 -0
  24. lamindb/curators/core.py +1546 -0
  25. lamindb/errors.py +11 -0
  26. lamindb/examples/__init__.py +27 -0
  27. lamindb/examples/schemas/__init__.py +12 -0
  28. lamindb/examples/schemas/_anndata.py +25 -0
  29. lamindb/examples/schemas/_simple.py +19 -0
  30. lamindb/integrations/_vitessce.py +8 -5
  31. lamindb/migrations/0091_alter_featurevalue_options_alter_space_options_and_more.py +24 -0
  32. lamindb/migrations/0092_alter_artifactfeaturevalue_artifact_and_more.py +75 -0
  33. lamindb/models/__init__.py +12 -2
  34. lamindb/models/_describe.py +21 -4
  35. lamindb/models/_feature_manager.py +384 -301
  36. lamindb/models/_from_values.py +1 -1
  37. lamindb/models/_is_versioned.py +5 -15
  38. lamindb/models/_label_manager.py +8 -2
  39. lamindb/models/artifact.py +354 -177
  40. lamindb/models/artifact_set.py +122 -0
  41. lamindb/models/can_curate.py +4 -1
  42. lamindb/models/collection.py +79 -56
  43. lamindb/models/core.py +1 -1
  44. lamindb/models/feature.py +78 -47
  45. lamindb/models/has_parents.py +24 -9
  46. lamindb/models/project.py +3 -3
  47. lamindb/models/query_manager.py +221 -22
  48. lamindb/models/query_set.py +251 -206
  49. lamindb/models/record.py +211 -344
  50. lamindb/models/run.py +59 -5
  51. lamindb/models/save.py +9 -5
  52. lamindb/models/schema.py +673 -196
  53. lamindb/models/transform.py +5 -14
  54. lamindb/models/ulabel.py +8 -5
  55. {lamindb-1.3.2.dist-info → lamindb-1.5.0.dist-info}/METADATA +8 -7
  56. lamindb-1.5.0.dist-info/RECORD +108 -0
  57. lamindb-1.3.2.dist-info/RECORD +0 -95
  58. {lamindb-1.3.2.dist-info → lamindb-1.5.0.dist-info}/LICENSE +0 -0
  59. {lamindb-1.3.2.dist-info → lamindb-1.5.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1546 @@
+ """Curator utilities.
+
+ .. autosummary::
+    :toctree: .
+
+    Curator
+    SlotsCurator
+    CatVector
+    CatLookup
+    DataFrameCatManager
+
+ """
+
+ from __future__ import annotations
+
+ import copy
+ import re
+ from typing import TYPE_CHECKING, Any, Callable
+
+ import lamindb_setup as ln_setup
+ import numpy as np
+ import pandas as pd
+ import pandera
+ from lamin_utils import colors, logger
+ from lamindb_setup.core._docs import doc_args
+
+ from lamindb.base.types import FieldAttr  # noqa
+ from lamindb.models import (
+     Artifact,
+     Feature,
+     Record,
+     Run,
+     Schema,
+ )
+ from lamindb.models._from_values import _format_values
+ from lamindb.models.artifact import (
+     data_is_anndata,
+     data_is_mudata,
+     data_is_spatialdata,
+ )
+ from lamindb.models.feature import parse_cat_dtype, parse_dtype
+
+ from ..errors import InvalidArgument, ValidationError
+
+ if TYPE_CHECKING:
+     from collections.abc import Iterable
+
+     from anndata import AnnData
+     from mudata import MuData
+     from spatialdata import SpatialData
+
+     from lamindb.models.query_set import RecordList
+
+
+ def strip_ansi_codes(text):
+     # This pattern matches ANSI escape sequences
+     ansi_pattern = re.compile(r"\x1b\[[0-9;]*m")
+     return ansi_pattern.sub("", text)
+
+
+ class CatLookup:
+     """Lookup categories from the reference instance.
+
+     Args:
+         categoricals: A dictionary of categorical fields to lookup.
+         slots: A dictionary of slot fields to lookup.
+         public: Whether to lookup from the public instance. Defaults to False.
+         sources: A dictionary of sources to lookup from.
+
+     Example::
+
+         curator = ln.curators.DataFrameCurator(...)
+         curator.cat.lookup()["cell_type"].alveolar_type_1_fibroblast_cell
+
+     """
+
+     def __init__(
+         self,
+         categoricals: list[Feature] | dict[str, FieldAttr],
+         slots: dict[str, FieldAttr] | None = None,
+         public: bool = False,
+         sources: dict[str, Record] | None = None,
+     ) -> None:
+         slots = slots or {}
+         if isinstance(categoricals, list):
+             categoricals = {
+                 feature.name: parse_dtype(feature.dtype)[0]["field"]
+                 for feature in categoricals
+             }
+         self._categoricals = {**categoricals, **slots}
+         self._public = public
+         self._sources = sources or {}  # guard against None in the lookups below
+
+     def __getattr__(self, name):
+         if name in self._categoricals:
+             registry = self._categoricals[name].field.model
+             if self._public and hasattr(registry, "public"):
+                 return registry.public(source=self._sources.get(name)).lookup()
+             else:
+                 return registry.lookup()
+         raise AttributeError(
+             f'"{self.__class__.__name__}" object has no attribute "{name}"'
+         )
+
+     def __getitem__(self, name):
+         if name in self._categoricals:
+             registry = self._categoricals[name].field.model
+             if self._public and hasattr(registry, "public"):
+                 return registry.public(source=self._sources.get(name)).lookup()
+             else:
+                 return registry.lookup()
+         raise AttributeError(
+             f'"{self.__class__.__name__}" object has no attribute "{name}"'
+         )
+
+     def __repr__(self) -> str:
+         if len(self._categoricals) > 0:
+             getattr_keys = "\n ".join(
+                 [f".{key}" for key in self._categoricals if key.isidentifier()]
+             )
+             getitem_keys = "\n ".join(
+                 [str([key]) for key in self._categoricals if not key.isidentifier()]
+             )
+             ref = "public" if self._public else "registries"
+             return (
+                 f"Lookup objects from the {colors.italic(ref)}:\n "
+                 f"{colors.green(getattr_keys)}\n "
+                 f"{colors.green(getitem_keys)}\n"
+                 'Example:\n    → categories = curator.lookup()["cell_type"]\n'
+                 "    → categories.alveolar_type_1_fibroblast_cell\n\n"
+                 "To look up public ontologies, use .lookup(public=True)"
+             )
+         else:  # pragma: no cover
+             return colors.warning("No fields are found!")
+
+
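
A minimal lookup sketch, assuming a `DataFrameCurator` named `curator` already exists for a dataframe with a `cell_type` column (all names hypothetical):

    import lamindb as ln

    curator = ln.curators.DataFrameCurator(df, schema)  # df and schema assumed to exist
    lookup = curator.cat.lookup()  # dot access works for identifier-safe keys
    lookup["cell_type"].alveolar_type_1_fibroblast_cell
    public = curator.cat.lookup(public=True)  # look up public ontologies instead
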
+ CAT_MANAGER_DOCSTRING = """Manage categoricals by updating registries."""
+
+
+ SLOTS_DOCSTRING = """Access sub curators by slot."""
+
+
+ VALIDATE_DOCSTRING = """Validate dataset against Schema.
+
+ Raises:
+     lamindb.errors.ValidationError: If validation fails.
+ """
+
+ SAVE_ARTIFACT_DOCSTRING = """Save an annotated artifact.
+
+ Args:
+     key: A path-like key to reference artifact in default storage, e.g., `"myfolder/myfile.fcs"`. Artifacts with the same key form a version family.
+     description: A description.
+     revises: Previous version of the artifact. An alternative way of passing `key` to trigger a new version.
+     run: The run that creates the artifact.
+
+ Returns:
+     A saved artifact record.
+ """
+
+
+ class Curator:
+     """Curator base class.
+
+     A `Curator` object makes it easy to validate, standardize & annotate datasets.
+
+     See:
+         - :class:`~lamindb.curators.DataFrameCurator`
+         - :class:`~lamindb.curators.AnnDataCurator`
+         - :class:`~lamindb.curators.MuDataCurator`
+         - :class:`~lamindb.curators.SpatialDataCurator`
+     """
+
+     def __init__(self, dataset: Any, schema: Schema | None = None):
+         self._artifact: Artifact | None = None  # pass the dataset as an artifact
+         self._dataset: Any = dataset  # pass the dataset as a UPathStr or data object
+         if isinstance(self._dataset, Artifact):
+             self._artifact = self._dataset
+             if self._artifact.otype in {
+                 "DataFrame",
+                 "AnnData",
+                 "MuData",
+                 "SpatialData",
+             }:
+                 self._dataset = self._dataset.load(is_run_input=False)
+         self._schema: Schema | None = schema
+         self._is_validated: bool = False
+
+     @doc_args(VALIDATE_DOCSTRING)
+     def validate(self) -> bool | str:
+         """{}"""  # noqa: D415
+         pass  # pragma: no cover
+
+     @doc_args(SAVE_ARTIFACT_DOCSTRING)
+     def save_artifact(
+         self,
+         *,
+         key: str | None = None,
+         description: str | None = None,
+         revises: Artifact | None = None,
+         run: Run | None = None,
+     ) -> Artifact:
+         """{}"""  # noqa: D415
+         # Note that this docstring has to be consistent with the Artifact()
+         # constructor signature
+         pass  # pragma: no cover
+
+     def __repr__(self) -> str:
+         from lamin_utils import colors
+
+         if self._schema is not None:
+             # Schema might have different attributes
+             if hasattr(self._schema, "name") and self._schema.name:
+                 schema_str = colors.italic(self._schema.name)
+             elif hasattr(self._schema, "uid"):
+                 schema_str = colors.italic(f"uid={self._schema.uid}")
+             elif hasattr(self._schema, "id"):
+                 schema_str = colors.italic(f"id={self._schema.id}")
+             else:
+                 schema_str = colors.italic("unnamed")
+
+             # Add schema type info if available
+             if hasattr(self._schema, "otype") and self._schema.otype:
+                 schema_str += f" ({self._schema.otype})"
+         else:
+             schema_str = colors.warning("None")
+
+         if self._is_validated:
+             status_str = f", {colors.green('validated')}"
+         else:
+             status_str = f", {colors.yellow('unvalidated')}"
+
+         cls_name = colors.green(self.__class__.__name__)
+
+         # Get additional info based on curator type
+         extra_info = ""
+         if hasattr(self, "_slots") and self._slots:
+             # For SlotsCurator and its subclasses
+             slots_count = len(self._slots)
+             if slots_count > 0:
+                 slot_names = list(self._slots.keys())
+                 if len(slot_names) <= 3:
+                     extra_info = f", slots: {slot_names}"
+                 else:
+                     extra_info = f", slots: [{', '.join(slot_names[:3])}... +{len(slot_names) - 3} more]"
+         elif (
+             # compare the plain class name, not the ANSI-colored cls_name string
+             self.__class__.__name__ == "DataFrameCurator"
+             and hasattr(self, "cat")
+             and hasattr(self.cat, "_categoricals")
+         ):
+             # For DataFrameCurator
+             cat_count = len(getattr(self.cat, "_categoricals", []))
+             if cat_count > 0:
+                 extra_info = f", categorical_features={cat_count}"
+
+         artifact_info = ""
+         if self._artifact is not None:
+             artifact_uid = getattr(self._artifact, "uid", str(self._artifact))
+             short_uid = (
+                 str(artifact_uid)[:8] + "..."
+                 if len(str(artifact_uid)) > 8
+                 else str(artifact_uid)
+             )
+             artifact_info = f", artifact: {colors.italic(short_uid)}"
+
+         return (
+             f"{cls_name}{artifact_info}(Schema: {schema_str}{extra_info}{status_str})"
+         )
+
+
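
All curator subclasses follow the same two-step flow sketched below; the names are assumptions, and `validate()` raises `ValidationError` rather than silently returning on failure:

    import lamindb as ln

    curator = ln.curators.DataFrameCurator(dataset, schema)  # dataset and schema assumed
    try:
        curator.validate()
    except ln.errors.ValidationError:
        # fix the dataset (standardize, register labels, ...) and re-validate
        raise
    artifact = curator.save_artifact(key="myfolder/myfile.parquet")
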
+ # default implementation for AnnDataCurator, MuDataCurator, and SpatialDataCurator
+ class SlotsCurator(Curator):
+     """Curator for a dataset with slots.
+
+     Args:
+         dataset: The dataset to validate & annotate.
+         schema: A :class:`~lamindb.Schema` object that defines the validation constraints.
+
+     """
+
+     def __init__(
+         self,
+         dataset: Any,
+         schema: Schema,
+     ) -> None:
+         super().__init__(dataset=dataset, schema=schema)
+         self._slots: dict[str, DataFrameCurator] = {}
+
+         # used in MuDataCurator and SpatialDataCurator
+         # in form of {table/modality_key: var_field}
+         self._var_fields: dict[str, FieldAttr] = {}
+         # in form of {table/modality_key: categoricals}
+         self._cat_vectors: dict[str, dict[str, CatVector]] = {}
+
+     @property
+     @doc_args(SLOTS_DOCSTRING)
+     def slots(self) -> dict[str, DataFrameCurator]:
+         """{}"""  # noqa: D415
+         return self._slots
+
+     @doc_args(VALIDATE_DOCSTRING)
+     def validate(self) -> None:
+         """{}"""  # noqa: D415
+         for slot, curator in self._slots.items():
+             logger.info(f"validating slot {slot} ...")
+             curator.validate()
+
+     @doc_args(SAVE_ARTIFACT_DOCSTRING)
+     def save_artifact(
+         self,
+         *,
+         key: str | None = None,
+         description: str | None = None,
+         revises: Artifact | None = None,
+         run: Run | None = None,
+     ) -> Artifact:
+         """{}"""  # noqa: D415
+         if not self._is_validated:
+             self.validate()
+         if self._artifact is None:
+             if data_is_anndata(self._dataset):
+                 self._artifact = Artifact.from_anndata(
+                     self._dataset,
+                     key=key,
+                     description=description,
+                     revises=revises,
+                     run=run,
+                 )
+             elif data_is_mudata(self._dataset):
+                 self._artifact = Artifact.from_mudata(
+                     self._dataset,
+                     key=key,
+                     description=description,
+                     revises=revises,
+                     run=run,
+                 )
+             elif data_is_spatialdata(self._dataset):
+                 self._artifact = Artifact.from_spatialdata(
+                     self._dataset,
+                     key=key,
+                     description=description,
+                     revises=revises,
+                     run=run,
+                 )
+         self._artifact.schema = self._schema
+         self._artifact.save()
+         cat_vectors = {}
+         for curator in self._slots.values():
+             # cat_key avoids shadowing the `key` argument above
+             for cat_key, cat_vector in curator.cat._cat_vectors.items():
+                 cat_vectors[cat_key] = cat_vector
+         return annotate_artifact(  # type: ignore
+             self._artifact,
+             curator=self,
+             cat_vectors=cat_vectors,
+         )
+
+
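
Each slot is backed by a plain `DataFrameCurator`, so slot-level fix-ups use the same dataframe API; a sketch assuming an `AnnDataCurator` over an existing `adata` with a hypothetical `perturbation` column:

    curator = ln.curators.AnnDataCurator(adata, schema)  # adata and schema assumed
    curator.slots["obs"].validate()
    curator.slots["obs"].cat.add_new_from("perturbation")  # register new labels
    artifact = curator.save_artifact(key="datasets/example.h5ad")
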
+ def check_dtype(expected_type) -> Callable:
+     """Creates a check function for Pandera that validates a column's dtype.
+
+     Args:
+         expected_type: String identifier for the expected type ('int', 'float', or 'num').
+
+     Returns:
+         A function that checks whether a series has the expected dtype.
+     """
+
+     def check_function(series):
+         if expected_type == "int":
+             is_valid = pd.api.types.is_integer_dtype(series.dtype)
+         elif expected_type == "float":
+             is_valid = pd.api.types.is_float_dtype(series.dtype)
+         elif expected_type == "num":
+             is_valid = pd.api.types.is_numeric_dtype(series.dtype)
+         else:  # unknown identifier rather than an undefined variable
+             is_valid = False
+         return is_valid
+
+     return check_function
+
+
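
Because `element_wise=False`, the returned closure receives the whole column once and returns a single boolean; a self-contained sketch of how it plugs into pandera:

    import pandas as pd
    import pandera

    column = pandera.Column(
        dtype=None,  # the dtype is asserted by the check, not coerced by pandera
        checks=pandera.Check(check_dtype("int"), element_wise=False),
    )
    schema = pandera.DataFrameSchema({"n_cells": column})
    schema.validate(pd.DataFrame({"n_cells": [1, 2, 3]}))  # int64 passes the "int" check
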
+ # this is also currently used as DictCurator
+ class DataFrameCurator(Curator):
+     # the example in the docstring is tested in test_curators_quickstart_example
+     """Curator for `DataFrame`.
+
+     Args:
+         dataset: The DataFrame-like object to validate & annotate.
+         schema: A :class:`~lamindb.Schema` object that defines the validation constraints.
+         slot: Indicate the slot in a composite curator for a composite data structure.
+
+     Example:
+
+         For a simple example using a flexible schema, see :meth:`~lamindb.Artifact.from_df`.
+
+         Here is an example that enforces a minimal set of columns in the dataframe.
+
+         .. literalinclude:: scripts/curate_dataframe_minimal_errors.py
+             :language: python
+
+         Under the hood, this uses the following schema.
+
+         .. literalinclude:: scripts/define_mini_immuno_schema_flexible.py
+             :language: python
+
+         Valid features & labels were defined as:
+
+         .. literalinclude:: scripts/define_mini_immuno_features_labels.py
+             :language: python
+     """
+
+     def __init__(
+         self,
+         dataset: pd.DataFrame | Artifact,
+         schema: Schema,
+         slot: str | None = None,
+     ) -> None:
+         super().__init__(dataset=dataset, schema=schema)
+         categoricals = []
+         features = []
+         feature_ids: set[int] = set()
+         if schema.flexible:
+             features += Feature.filter(name__in=self._dataset.keys()).list()
+             feature_ids = {feature.id for feature in features}
+         if schema.n > 0:
+             if schema._index_feature_uid is not None:
+                 schema_features = [
+                     feature
+                     for feature in schema.members.list()
+                     if feature.uid != schema._index_feature_uid  # type: ignore
+                 ]
+             else:
+                 schema_features = schema.members.list()  # type: ignore
+             if feature_ids:
+                 features.extend(
+                     feature
+                     for feature in schema_features
+                     if feature.id not in feature_ids  # type: ignore
+                 )
+             else:
+                 features.extend(schema_features)
+         else:
+             assert schema.itype is not None  # noqa: S101
+         pandera_columns = {}
+         if features or schema._index_feature_uid is not None:
+             # populate features
+             if schema.minimal_set:
+                 optional_feature_uids = set(schema.optionals.get_uids())
+             for feature in features:
+                 if schema.minimal_set:
+                     required = feature.uid not in optional_feature_uids
+                 else:
+                     required = False
+                 if feature.dtype in {"int", "float", "num"}:
+                     if isinstance(self._dataset, pd.DataFrame):
+                         dtype = (
+                             self._dataset[feature.name].dtype
+                             if feature.name in self._dataset.keys()
+                             else None
+                         )
+                     else:
+                         dtype = None
+                     pandera_columns[feature.name] = pandera.Column(
+                         dtype=None,
+                         checks=pandera.Check(
+                             check_dtype(feature.dtype),
+                             element_wise=False,
+                             error=f"Column '{feature.name}' failed dtype check for '{feature.dtype}': got {dtype}",
+                         ),
+                         nullable=feature.nullable,
+                         coerce=feature.coerce_dtype,
+                         required=required,
+                     )
+                 else:
+                     pandera_dtype = (
+                         feature.dtype
+                         if not feature.dtype.startswith("cat")
+                         else "category"
+                     )
+                     pandera_columns[feature.name] = pandera.Column(
+                         pandera_dtype,
+                         nullable=feature.nullable,
+                         coerce=feature.coerce_dtype,
+                         required=required,
+                     )
+                 if feature.dtype.startswith("cat"):
+                     # validate categoricals if the column is required or if the column is present
+                     if required or feature.name in self._dataset.keys():
+                         categoricals.append(feature)
+             if schema._index_feature_uid is not None:
+                 # in almost no case should an index have a pandas.CategoricalDtype in a DataFrame,
+                 # so we type it as `str` here
+                 index = pandera.Index(
+                     schema.index.dtype
+                     if not schema.index.dtype.startswith("cat")
+                     else str
+                 )
+             else:
+                 index = None
+             self._pandera_schema = pandera.DataFrameSchema(
+                 pandera_columns,
+                 coerce=schema.coerce_dtype,
+                 strict=schema.maximal_set,
+                 ordered=schema.ordered_set,
+                 index=index,
+             )
+         self._cat_manager = DataFrameCatManager(
+             self._dataset,
+             columns_field=parse_cat_dtype(schema.itype, is_itype=True)["field"],
+             columns_names=pandera_columns.keys(),
+             categoricals=categoricals,
+             index=schema.index,
+             slot=slot,
+             schema_maximal_set=schema.maximal_set,
+         )
+
+     @property
+     @doc_args(CAT_MANAGER_DOCSTRING)
+     def cat(self) -> DataFrameCatManager:
+         """{}"""  # noqa: D415
+         return self._cat_manager
+
+     def standardize(self) -> None:
+         """Standardize the dataset.
+
+         - Adds missing columns for features
+         - Fills missing values for features with default values
+         """
+         for feature in self._schema.members:
+             if feature.name not in self._dataset.columns:
+                 if feature.default_value is not None or feature.nullable:
+                     fill_value = (
+                         feature.default_value
+                         if feature.default_value is not None
+                         else pd.NA
+                     )
+                     if feature.dtype.startswith("cat"):
+                         self._dataset[feature.name] = pd.Categorical(
+                             [fill_value] * len(self._dataset)
+                         )
+                     else:
+                         self._dataset[feature.name] = fill_value
+                     logger.important(
+                         f"added column {feature.name} with fill value {fill_value}"
+                     )
+                 else:
+                     raise ValidationError(
+                         f"Missing column {feature.name} cannot be added because it is not nullable and has no default value"
+                     )
+             else:
+                 if feature.default_value is not None:
+                     if isinstance(
+                         self._dataset[feature.name].dtype, pd.CategoricalDtype
+                     ):
+                         if (
+                             feature.default_value
+                             not in self._dataset[feature.name].cat.categories
+                         ):
+                             self._dataset[feature.name] = self._dataset[
+                                 feature.name
+                             ].cat.add_categories(feature.default_value)
+                     self._dataset[feature.name] = self._dataset[feature.name].fillna(
+                         feature.default_value
+                     )
+
+     def _cat_manager_validate(self) -> None:
+         self.cat.validate()
+
+         if self.cat._is_validated:
+             self._is_validated = True
+         else:
+             self._is_validated = False
+             raise ValidationError(self.cat._validate_category_error_messages)
+
+     @doc_args(VALIDATE_DOCSTRING)
+     def validate(self) -> None:
+         """{}"""  # noqa: D415
+         if self._schema.n > 0:
+             try:
+                 # first validate through pandera
+                 self._pandera_schema.validate(self._dataset)
+                 # then validate lamindb categoricals
+                 self._cat_manager_validate()
+             except pandera.errors.SchemaError as err:
+                 self._is_validated = False
+                 # .exconly() doesn't exist on SchemaError
+                 raise ValidationError(str(err)) from err
+         else:
+             self._cat_manager_validate()
+
+     @doc_args(SAVE_ARTIFACT_DOCSTRING)
+     def save_artifact(
+         self,
+         *,
+         key: str | None = None,
+         description: str | None = None,
+         revises: Artifact | None = None,
+         run: Run | None = None,
+     ) -> Artifact:
+         """{}"""  # noqa: D415
+         if not self._is_validated:
+             self.validate()  # raises ValidationError if it doesn't validate
+         if self._artifact is None:
+             self._artifact = Artifact.from_df(
+                 self._dataset,
+                 key=key,
+                 description=description,
+                 revises=revises,
+                 run=run,
+                 # guard against key=None before checking the suffix
+                 format=".csv" if key is not None and key.endswith(".csv") else None,
+             )
+         self._artifact.schema = self._schema
+         self._artifact.save()
+         return annotate_artifact(  # type: ignore
+             self._artifact,
+             cat_vectors=self.cat._cat_vectors,
+         )
+
+
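
A minimal end-to-end sketch with hypothetical feature and file names, assuming a fresh instance:

    import lamindb as ln
    import pandas as pd

    ln.Feature(name="perturbation", dtype="cat[ULabel]").save()
    schema = ln.Schema(features=[ln.Feature.get(name="perturbation")]).save()
    df = pd.DataFrame({"perturbation": ["DMSO", "IFNG"]})
    curator = ln.curators.DataFrameCurator(df, schema)
    curator.cat.add_new_from("perturbation")  # register the two labels as ULabels
    curator.validate()
    artifact = curator.save_artifact(key="examples/mini.parquet")
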
+ class AnnDataCurator(SlotsCurator):
+     """Curator for `AnnData`.
+
+     Args:
+         dataset: The AnnData-like object to validate & annotate.
+         schema: A :class:`~lamindb.Schema` object that defines the validation constraints.
+
+     Example:
+
+         See :meth:`~lamindb.Artifact.from_anndata`.
+
+     """
+
+     def __init__(
+         self,
+         dataset: AnnData | Artifact,
+         schema: Schema,
+     ) -> None:
+         super().__init__(dataset=dataset, schema=schema)
+         if not data_is_anndata(self._dataset):
+             raise InvalidArgument("dataset must be AnnData-like.")
+         if schema.otype != "AnnData":
+             raise InvalidArgument("Schema otype must be 'AnnData'.")
+         self._slots = {
+             slot: DataFrameCurator(
+                 (
+                     # removesuffix (not strip) so that only a literal ".T" suffix is dropped
+                     getattr(self._dataset, slot.removesuffix(".T")).T
+                     if slot == "var.T"
+                     or (
+                         # backward compat
+                         slot == "var"
+                         and schema.slots["var"].itype not in {None, "Feature"}
+                     )
+                     else getattr(self._dataset, slot)
+                 ),
+                 slot_schema,
+                 slot=slot,
+             )
+             for slot, slot_schema in schema.slots.items()
+             if slot in {"obs", "var", "var.T", "uns"}
+         }
+         if "var" in self._slots and schema.slots["var"].itype not in {None, "Feature"}:
+             logger.warning(
+                 "auto-transposed `var` for backward compat, please indicate transposition in the schema definition by calling out `.T`: slots={'var.T': itype=bt.Gene.ensembl_gene_id}"
+             )
+             self._slots["var"].cat._cat_vectors["var_index"] = self._slots[
+                 "var"
+             ].cat._cat_vectors.pop("columns")
+             self._slots["var"].cat._cat_vectors["var_index"]._key = "var_index"
+
+
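
The recognized slot keys are `obs`, `var`, `var.T`, and `uns`; a schema sketch using the `var.T` convention called out in the warning above (`obs_schema` and `adata` are assumed to exist, and bionty is assumed to be installed):

    import bionty as bt
    import lamindb as ln

    anndata_schema = ln.Schema(
        otype="AnnData",
        slots={
            "obs": obs_schema,  # a dataframe-level Schema
            "var.T": ln.Schema(itype=bt.Gene.ensembl_gene_id),  # validates the var index
        },
    ).save()
    curator = ln.curators.AnnDataCurator(adata, anndata_schema)
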
+ def _assign_var_fields_categoricals_multimodal(
+     modality: str | None,
+     slot_type: str,
+     slot: str,
+     slot_schema: Schema,
+     var_fields: dict[str, FieldAttr],
+     cat_vectors: dict[str, dict[str, CatVector]],
+     slots: dict[str, DataFrameCurator],
+ ) -> None:
+     """Assigns var_fields and categoricals for multimodal data curators."""
+     if modality is not None:
+         # Makes sure that all tables are present
+         var_fields[modality] = None
+         cat_vectors[modality] = {}
+
+     if slot_type == "var":
+         var_field = parse_cat_dtype(slot_schema.itype, is_itype=True)["field"]
+         if modality is None:
+             # This should rarely/never be used since tables should have different var fields
+             var_fields[slot] = var_field  # pragma: no cover
+         else:
+             # Note that this is NOT nested since the nested key is always "var"
+             var_fields[modality] = var_field
+     else:
+         obs_fields = slots[slot].cat._cat_vectors
+         if modality is None:
+             cat_vectors[slot] = obs_fields
+         else:
+             # Note that this is NOT nested since the nested key is always "obs"
+             cat_vectors[modality] = obs_fields
+
+
+ class MuDataCurator(SlotsCurator):
+     """Curator for `MuData`.
+
+     Args:
+         dataset: The MuData-like object to validate & annotate.
+         schema: A :class:`~lamindb.Schema` object that defines the validation constraints.
+
+     Example:
+
+         .. literalinclude:: scripts/curate-mudata.py
+             :language: python
+             :caption: curate-mudata.py
+     """
+
+     def __init__(
+         self,
+         dataset: MuData | Artifact,
+         schema: Schema,
+     ) -> None:
+         super().__init__(dataset=dataset, schema=schema)
+         if not data_is_mudata(self._dataset):
+             raise InvalidArgument("dataset must be MuData-like.")
+         if schema.otype != "MuData":
+             raise InvalidArgument("Schema otype must be 'MuData'.")
+
+         for slot, slot_schema in schema.slots.items():
+             if ":" in slot:
+                 modality, modality_slot = slot.split(":")
+                 schema_dataset = self._dataset[modality]
+             else:
+                 modality, modality_slot = None, slot
+                 schema_dataset = self._dataset
+             if modality_slot == "var" and schema.slots[slot].itype not in {
+                 None,
+                 "Feature",
+             }:
+                 logger.warning(
+                     "auto-transposed `var` for backward compat, please indicate transposition in the schema definition by calling out `.T`: slots={'var.T': itype=bt.Gene.ensembl_gene_id}"
+                 )
+             self._slots[slot] = DataFrameCurator(
+                 (
+                     # removesuffix (not rstrip) so that only a literal ".T" suffix is dropped
+                     getattr(schema_dataset, modality_slot.removesuffix(".T")).T
+                     if modality_slot == "var.T"
+                     or (
+                         # backward compat
+                         modality_slot == "var"
+                         and schema.slots[slot].itype not in {None, "Feature"}
+                     )
+                     else getattr(schema_dataset, modality_slot)
+                 ),
+                 slot_schema,
+             )
+             _assign_var_fields_categoricals_multimodal(
+                 modality=modality,
+                 slot_type=modality_slot,
+                 slot=slot,
+                 slot_schema=slot_schema,
+                 var_fields=self._var_fields,
+                 cat_vectors=self._cat_vectors,
+                 slots=self._slots,
+             )
+         self._columns_field = self._var_fields
+
+
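
Slot keys address modalities as `<modality>:<slot>`; a sketch assuming an `rna` modality and pre-existing dataframe-level schemas:

    mudata_schema = ln.Schema(
        otype="MuData",
        slots={
            "rna:obs": rna_obs_schema,  # assumed to exist
            "rna:var.T": ln.Schema(itype=bt.Gene.ensembl_gene_id),
        },
    ).save()
    curator = ln.curators.MuDataCurator(mdata, mudata_schema)  # mdata assumed to exist
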
+ class SpatialDataCurator(SlotsCurator):
+     """Curator for `SpatialData`.
+
+     Args:
+         dataset: The SpatialData-like object to validate & annotate.
+         schema: A :class:`~lamindb.Schema` object that defines the validation constraints.
+
+     Example:
+
+         See :meth:`~lamindb.Artifact.from_spatialdata`.
+     """
+
+     def __init__(
+         self,
+         dataset: SpatialData | Artifact,
+         schema: Schema,
+         *,
+         sample_metadata_key: str | None = "sample",
+     ) -> None:
+         super().__init__(dataset=dataset, schema=schema)
+         if not data_is_spatialdata(self._dataset):
+             raise InvalidArgument("dataset must be SpatialData-like.")
+         if schema.otype != "SpatialData":
+             raise InvalidArgument("Schema otype must be 'SpatialData'.")
+
+         for slot, slot_schema in schema.slots.items():
+             split_result = slot.split(":")
+             if (len(split_result) == 2 and split_result[0] == "table") or (
+                 len(split_result) == 3 and split_result[0] == "tables"
+             ):
+                 if len(split_result) == 2:
+                     table_key, sub_slot = split_result
+                     logger.warning(
+                         f"please prefix slot {slot} with 'tables:' going forward"
+                     )
+                 else:
+                     table_key, sub_slot = split_result[1], split_result[2]
+                 slot_object = self._dataset.tables[table_key]
+                 if sub_slot == "var" and schema.slots[slot].itype not in {
+                     None,
+                     "Feature",
+                 }:
+                     logger.warning(
+                         "auto-transposed `var` for backward compat, please indicate transposition in the schema definition by calling out `.T`: slots={'var.T': itype=bt.Gene.ensembl_gene_id}"
+                     )
+                 data_object = (
+                     # removesuffix (not rstrip) so that only a literal ".T" suffix is dropped
+                     getattr(slot_object, sub_slot.removesuffix(".T")).T
+                     if sub_slot == "var.T"
+                     or (
+                         # backward compat
+                         sub_slot == "var"
+                         and schema.slots[slot].itype not in {None, "Feature"}
+                     )
+                     else getattr(slot_object, sub_slot)
+                 )
+             elif len(split_result) == 1 or (
+                 len(split_result) > 1 and split_result[0] == "attrs"
+             ):
+                 table_key = None
+                 if len(split_result) == 1:
+                     if split_result[0] != "attrs":
+                         logger.warning(
+                             f"please prefix slot {slot} with 'attrs:' going forward"
+                         )
+                         sub_slot = slot
+                         data_object = self._dataset.attrs[slot]
+                     else:
+                         sub_slot = "attrs"
+                         data_object = self._dataset.attrs
+                 elif len(split_result) == 2:
+                     sub_slot = split_result[1]
+                     data_object = self._dataset.attrs[split_result[1]]
+                 data_object = pd.DataFrame([data_object])
+             self._slots[slot] = DataFrameCurator(data_object, slot_schema)
+             _assign_var_fields_categoricals_multimodal(
+                 modality=table_key,
+                 slot_type=sub_slot,
+                 slot=slot,
+                 slot_schema=slot_schema,
+                 var_fields=self._var_fields,
+                 cat_vectors=self._cat_vectors,
+                 slots=self._slots,
+             )
+         self._columns_field = self._var_fields
+
+
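
Slot keys take the forms `tables:<table_key>:<slot>` and `attrs:<key>` (the bare `table:` and un-prefixed forms are deprecated per the warnings above); a sketch with assumed sub-schemas:

    spatialdata_schema = ln.Schema(
        otype="SpatialData",
        slots={
            "attrs:sample": sample_schema,  # validates sdata.attrs["sample"]
            "tables:table:obs": table_obs_schema,  # validates sdata.tables["table"].obs
        },
    ).save()
    curator = ln.curators.SpatialDataCurator(sdata, spatialdata_schema)  # sdata assumed
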
+ class CatVector:
+     """Categorical vector for `DataFrame`.
+
+     Args:
+         values_getter: A callable or iterable that returns the values to validate.
+         field: The field to validate against.
+         key: The name of the column to validate. Only used for logging.
+         values_setter: A callable that sets the values.
+         source: The source to validate against.
+         feature: The feature to annotate with.
+         cat_manager: The manager that owns this vector.
+         subtype_str: A subtype constraint parsed from the dtype.
+         maximal_set: Whether unvalidated categoricals cause validation failure.
+     """
+
+     def __init__(
+         self,
+         values_getter: Callable | Iterable[str],
+         field: FieldAttr,
+         key: str,
+         values_setter: Callable | None = None,
+         source: Record | None = None,
+         feature: Feature | None = None,
+         cat_manager: DataFrameCatManager | None = None,
+         subtype_str: str = "",
+         maximal_set: bool = False,  # passed during validation; whether unvalidated categoricals cause validation failure
+     ) -> None:
+         self._values_getter = values_getter
+         self._values_setter = values_setter
+         self._field = field
+         self._key = key
+         self._source = source
+         self._organism = None
+         self._validated: None | list[str] = None
+         self._non_validated: None | list[str] = None
+         self._synonyms: None | dict[str, str] = None
+         self._subtype_str = subtype_str
+         self._subtype_query_set = None
+         self._cat_manager = cat_manager
+         self.feature = feature
+         self.records = None
+         self._maximal_set = maximal_set
+         if hasattr(field.field.model, "_name_field"):
+             label_ref_is_name = field.field.name == field.field.model._name_field
+         else:
+             label_ref_is_name = field.field.name == "name"
+         self.label_ref_is_name = label_ref_is_name
+
+     @property
+     def values(self):
+         """Get the current values using the getter function."""
+         if callable(self._values_getter):
+             return self._values_getter()
+         return self._values_getter
+
+     @values.setter
+     def values(self, new_values):
+         """Set new values using the setter function if available."""
+         if callable(self._values_setter):
+             self._values_setter(new_values)
+         else:
+             # if values_getter is not callable, it's a direct reference we can update
+             self._values_getter = new_values
+
+     @property
+     def is_validated(self) -> bool:
+         """Whether the vector is validated."""
+         # ensembl gene IDs pass even if they were not validated;
+         # this is a simple solution to the ensembl gene version problem
+         if self._field.field.attname == "ensembl_gene_id":
+             # if none of the ensembl gene IDs were validated, we are probably not looking at ensembl gene IDs
+             if len(self.values) == len(self._non_validated):
+                 return False
+             # if maximal set, we do not allow additional unvalidated genes
+             elif len(self._non_validated) != 0 and self._maximal_set:
+                 return False
+             return True
+         else:
+             return len(self._non_validated) == 0
+
+     def _replace_synonyms(self) -> list[str]:
+         """Replace synonyms in the vector with standardized values."""
+         syn_mapper = self._synonyms
+         # replace the values in the dataframe
+         std_values = self.values.map(
+             lambda unstd_val: syn_mapper.get(unstd_val, unstd_val)
+         )
+         # remove the standardized values from self._non_validated
+         self._non_validated = [i for i in self._non_validated if i not in syn_mapper]  # type: ignore
+         # logging
+         n = len(syn_mapper)
+         if n > 0:
+             syn_mapper_print = _format_values(
+                 [f'"{k}" → "{v}"' for k, v in syn_mapper.items()], sep=""
+             )
+             s = "s" if n > 1 else ""
+             logger.success(
+                 f'standardized {n} synonym{s} in "{self._key}": {colors.green(syn_mapper_print)}'
+             )
+         return std_values
+
+     def __repr__(self) -> str:
+         if self._non_validated is None:
+             status = "unvalidated"
+         else:
+             status = (
+                 "validated"
+                 if len(self._non_validated) == 0
+                 else f"non-validated ({len(self._non_validated)})"
+             )
+
+         field_name = getattr(self._field, "name", str(self._field))
+         values_count = len(self.values) if hasattr(self.values, "__len__") else "?"
+         return f"CatVector(key='{self._key}', field='{field_name}', values={values_count}, {status})"
+
+     def _add_validated(self) -> tuple[list, list]:
+         """Save features or labels records in the default instance."""
+         from lamindb.models.save import save as ln_save
+
+         registry = self._field.field.model
+         field_name = self._field.field.name
+         model_field = registry.__get_name_with_module__()
+         filter_kwargs = get_current_filter_kwargs(
+             registry, {"organism": self._organism, "source": self._source}
+         )
+         values = [i for i in self.values if isinstance(i, str) and i]
+         if not values:
+             return [], []
+         # inspect the default instance and save validated records from public
+         if (
+             self._subtype_str != "" and "__" not in self._subtype_str
+         ):  # not for general filter expressions
+             self._subtype_query_set = registry.get(name=self._subtype_str).records.all()
+             values_array = np.array(values)
+             validated_mask = self._subtype_query_set.validate(  # type: ignore
+                 values_array, field=self._field, **filter_kwargs, mute=True
+             )
+             validated_labels, non_validated_labels = (
+                 values_array[validated_mask],
+                 values_array[~validated_mask],
+             )
+             records = registry.from_values(
+                 validated_labels, field=self._field, **filter_kwargs, mute=True
+             )
+         else:
+             existing_and_public_records = registry.from_values(
+                 list(values), field=self._field, **filter_kwargs, mute=True
+             )
+             existing_and_public_labels = [
+                 getattr(r, field_name) for r in existing_and_public_records
+             ]
+             # public records that are not already in the database
+             public_records = [r for r in existing_and_public_records if r._state.adding]
+             # only save the public records if they are from the specified source;
+             # we check the uid because r.source and source can be from different instances
+             if self._source:
+                 public_records = [
+                     r for r in public_records if r.source.uid == self._source.uid
+                 ]
+             if len(public_records) > 0:
+                 logger.info(f"saving validated records of '{self._key}'")
+                 ln_save(public_records)
+             labels_saved_public = [getattr(r, field_name) for r in public_records]
+             # log the saved public labels
+             # the term "transferred" stresses that this is always in the context of transferring
+             # labels from a public ontology or a different instance to the present instance
+             if len(labels_saved_public) > 0:
+                 s = "s" if len(labels_saved_public) > 1 else ""
+                 logger.success(
+                     f'added {len(labels_saved_public)} record{s} {colors.green("from_public")} with {model_field} for "{self._key}": {_format_values(labels_saved_public)}'
+                 )
+             # non-validated records from the default instance
+             non_validated_labels = [
+                 i for i in values if i not in existing_and_public_labels
+             ]
+             validated_labels = existing_and_public_labels
+             records = existing_and_public_records
+
+         self.records = records
+         # validated, non-validated
+         return validated_labels, non_validated_labels
+
+     def _add_new(
+         self,
+         values: list[str],
+         df: pd.DataFrame | None = None,  # remove when all users use schema
+         dtype: str | None = None,
+         **create_kwargs,
+     ) -> None:
+         """Add new labels to the registry."""
+         from lamindb.models.save import save as ln_save
+
+         registry = self._field.field.model
+         field_name = self._field.field.name
+         non_validated_records: RecordList[Any] = []  # type: ignore
+         if df is not None and registry == Feature:
+             nonval_columns = Feature.inspect(df.columns, mute=True).non_validated
+             non_validated_records = Feature.from_df(df.loc[:, nonval_columns])
+         else:
+             if (
+                 self._organism
+                 and hasattr(registry, "organism")
+                 and registry._meta.get_field("organism").is_relation
+             ):
+                 # make sure the organism record is saved to the current instance
+                 create_kwargs["organism"] = _save_organism(name=self._organism)
+
+             for value in values:
+                 init_kwargs = {field_name: value}
+                 if registry == Feature:
+                     init_kwargs["dtype"] = "cat" if dtype is None else dtype
+                 non_validated_records.append(registry(**init_kwargs, **create_kwargs))
+         if len(non_validated_records) > 0:
+             ln_save(non_validated_records)
+             model_field = colors.italic(registry.__get_name_with_module__())
+             s = "s" if len(values) > 1 else ""
+             logger.success(
+                 f'added {len(values)} record{s} with {model_field} for "{self._key}": {_format_values(values)}'
+             )
+
+     def _validate(
+         self,
+         values: list[str],
+     ) -> tuple[list[str], dict]:
+         """Validate ontology terms using LaminDB registries."""
+         registry = self._field.field.model
+         field_name = self._field.field.name
+         model_field = f"{registry.__name__}.{field_name}"
+
+         def _log_mapping_info():
+             logger.indent = ""
+             logger.info(f'mapping "{self._key}" on {colors.italic(model_field)}')
+             logger.indent = "   "
+
+         kwargs_current = get_current_filter_kwargs(
+             registry, {"organism": self._organism, "source": self._source}
+         )
+
+         # inspect values from the default instance, excluding public
+         registry_or_queryset = registry
+         if self._subtype_query_set is not None:
+             registry_or_queryset = self._subtype_query_set
+         inspect_result = registry_or_queryset.inspect(
+             values, field=self._field, mute=True, from_source=False, **kwargs_current
+         )
+         non_validated = inspect_result.non_validated
+         syn_mapper = inspect_result.synonyms_mapper
+
+         # inspect the non-validated values from public (BioRecord only)
+         values_validated = []
+         if hasattr(registry, "public"):
+             public_records = registry.from_values(
+                 non_validated,
+                 field=self._field,
+                 mute=True,
+                 **kwargs_current,
+             )
+             values_validated += [getattr(r, field_name) for r in public_records]
+
+         # logging messages
+         if self._cat_manager is not None:
+             slot = self._cat_manager._slot
+         else:
+             slot = None
+         in_slot = f" in slot '{slot}'" if slot is not None else ""
+         slot_prefix = f".slots['{slot}']" if slot is not None else ""
+         non_validated_hint_print = (
+             f"curator{slot_prefix}.cat.add_new_from('{self._key}')"
+         )
+         non_validated = [i for i in non_validated if i not in values_validated]
+         n_non_validated = len(non_validated)
+         if n_non_validated == 0:
+             logger.indent = ""
+             logger.success(
+                 f'"{self._key}" is validated against {colors.italic(model_field)}'
+             )
+             return [], {}
+         else:
+             s = "" if n_non_validated == 1 else "s"
+             print_values = _format_values(non_validated)
+             warning_message = f"{colors.red(f'{n_non_validated} term{s}')} not validated in feature '{self._key}'{in_slot}: {colors.red(print_values)}\n"
+             if syn_mapper:
+                 s = "" if len(syn_mapper) == 1 else "s"
+                 syn_mapper_print = _format_values(
+                     [f'"{k}" → "{v}"' for k, v in syn_mapper.items()], sep=""
+                 )
+                 hint_msg = f'.standardize("{self._key}")'
+                 warning_message += f"    {colors.yellow(f'{len(syn_mapper)} synonym{s}')} found: {colors.yellow(syn_mapper_print)}\n    → curate synonyms via: {colors.cyan(hint_msg)}"
+             if n_non_validated > len(syn_mapper):
+                 if syn_mapper:
+                     warning_message += "\n    for remaining terms:\n"
+                 warning_message += f"    → fix typos, remove non-existent values, or save terms via: {colors.cyan(non_validated_hint_print)}"
+             if self._subtype_query_set is not None:
+                 warning_message += f"\n    → a valid label for subtype '{self._subtype_str}' has to be one of {self._subtype_query_set.list('name')}"
+             if logger.indent == "":
+                 _log_mapping_info()
+             logger.warning(warning_message)
+             if self._cat_manager is not None:
+                 self._cat_manager._validate_category_error_messages = strip_ansi_codes(
+                     warning_message
+                 )
+             logger.indent = ""
+             return non_validated, syn_mapper
+
+     def validate(self) -> None:
+         """Validate the vector."""
+         # add source-validated values to the registry
+         self._validated, self._non_validated = self._add_validated()
+         self._non_validated, self._synonyms = self._validate(values=self._non_validated)
+
+         # always register new Features if they are columns
+         if self._key == "columns" and self._field == Feature.name:
+             self.add_new()
+
+     def standardize(self) -> None:
+         """Standardize the vector."""
+         registry = self._field.field.model
+         if not hasattr(registry, "standardize"):
+             return  # nothing to standardize
+         if self._synonyms is None:
+             self.validate()
+         # get standardized values
+         std_values = self._replace_synonyms()
+         # update non_validated values
+         self._non_validated = [
+             i for i in self._non_validated if i not in self._synonyms.keys()
+         ]
+         # remove synonyms since they are now standardized
+         self._synonyms = {}
+         # update the values with the standardized values
+         self.values = std_values
+
+     def add_new(self, **create_kwargs) -> None:
+         """Add new values to the registry."""
+         if self._non_validated is None:
+             self.validate()
+         if len(self._synonyms) > 0:
+             # raise an error because .standardize modifies the input dataset
+             raise ValidationError(
+                 "Please run `.standardize()` before adding new values."
+             )
+         self._add_new(
+             values=self._non_validated,
+             **create_kwargs,
+         )
+         # remove the non_validated values since they are now registered
+         self._non_validated = []
+
+
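
`CatVector` mirrors how `DataFrameCatManager` wires it up and can wrap any column-like value with a getter/setter pair; a sketch assuming bionty's `CellType` registry and an existing dataframe `df`:

    import bionty as bt

    vector = CatVector(
        values_getter=lambda: df["cell_type"],
        values_setter=lambda new_values: df.__setitem__("cell_type", new_values),
        field=bt.CellType.name,
        key="cell_type",
    )
    vector.validate()  # populates .records and the non-validated values
    vector.standardize()  # writes synonym-mapped values back through the setter
    vector.is_validated
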
+ class DataFrameCatManager:
+     """Manage categoricals by updating registries.
+
+     This class is accessible from within a `DataFrameCurator` via the `.cat` attribute.
+
+     If you find non-validated values, you have two options:
+
+     - new values found in the data can be registered via :meth:`~lamindb.curators.core.DataFrameCatManager.add_new_from`
+     - non-validated values can be accessed via :meth:`~lamindb.curators.core.DataFrameCatManager.non_validated` and addressed manually
+     """
+
+     def __init__(
+         self,
+         df: pd.DataFrame | Artifact,
+         columns_field: FieldAttr = Feature.name,
+         columns_names: Iterable[str] | None = None,
+         categoricals: list[Feature] | None = None,
+         sources: dict[str, Record] | None = None,
+         index: Feature | None = None,
+         slot: str | None = None,
+         schema_maximal_set: bool = False,
+     ) -> None:
+         self._non_validated = None
+         self._index = index
+         self._artifact: Artifact | None = None  # pass the dataset as an artifact
+         self._dataset: Any = df  # pass the dataset as a UPathStr or data object
+         if isinstance(self._dataset, Artifact):
+             self._artifact = self._dataset
+             self._dataset = self._dataset.load(is_run_input=False)
+         self._is_validated: bool = False
+         self._categoricals = categoricals or []
+         self._sources = sources or {}
+         self._columns_field = columns_field
+         self._validate_category_error_messages: str = ""
+         self._cat_vectors: dict[str, CatVector] = {}
+         self._slot = slot
+         self._maximal_set = schema_maximal_set
+
+         if columns_names is None:
+             columns_names = []
+         if columns_field == Feature.name:
+             self._cat_vectors["columns"] = CatVector(
+                 values_getter=columns_names,
+                 field=columns_field,
+                 key="columns" if isinstance(self._dataset, pd.DataFrame) else "keys",
+                 source=self._sources.get("columns"),
+                 cat_manager=self,
+                 maximal_set=self._maximal_set,
+             )
+         else:
+             self._cat_vectors["columns"] = CatVector(
+                 values_getter=lambda: self._dataset.columns,  # lambda ensures the inplace update
+                 values_setter=lambda new_values: setattr(
+                     self._dataset, "columns", pd.Index(new_values)
+                 ),
+                 field=columns_field,
+                 key="columns",
+                 source=self._sources.get("columns"),
+                 cat_manager=self,
+                 maximal_set=self._maximal_set,
+             )
+         for feature in self._categoricals:
+             result = parse_dtype(feature.dtype)[
+                 0
+             ]  # TODO: support composite dtypes for categoricals
+             key = feature.name
+             field = result["field"]
+             subtype_str = result["subtype_str"]
+             self._cat_vectors[key] = CatVector(
+                 values_getter=lambda k=key: self._dataset[
+                     k
+                 ],  # capture key as a default argument
+                 values_setter=lambda new_values, k=key: self._dataset.__setitem__(
+                     k, new_values
+                 ),
+                 field=field,
+                 key=key,
+                 source=self._sources.get(key),
+                 feature=feature,
+                 cat_manager=self,
+                 subtype_str=subtype_str,
+                 maximal_set=self._maximal_set,
+             )
+         if index is not None and index.dtype.startswith("cat"):
+             result = parse_dtype(index.dtype)[0]
+             field = result["field"]
+             key = "index"
+             self._cat_vectors[key] = CatVector(
+                 values_getter=self._dataset.index,
+                 field=field,
+                 key=key,
+                 feature=index,
+                 cat_manager=self,
+                 maximal_set=self._maximal_set,
+             )
+
+     @property
+     def non_validated(self) -> dict[str, list[str]]:
+         """Return the non-validated features and labels."""
+         if self._non_validated is None:
+             raise ValidationError("Please run validate() first!")
+         return {
+             key: cat_vector._non_validated
+             for key, cat_vector in self._cat_vectors.items()
+             if cat_vector._non_validated and key != "columns"
+         }
+
+     @property
+     def categoricals(self) -> list[Feature]:
+         """The categorical features."""
+         return self._categoricals
+
+     def lookup(self, public: bool = False) -> CatLookup:
+         """Lookup categories.
+
+         Args:
+             public: If `True`, the lookup is performed on the public reference.
+         """
+         return CatLookup(
+             categoricals=self._categoricals,
+             slots={"columns": self._columns_field},
+             public=public,
+             sources=self._sources,
+         )
+
+     def validate(self) -> bool:
+         """Validate variables and categorical observations."""
+         self._validate_category_error_messages = ""  # reset the error messages
+
+         validated = True
+         for key, cat_vector in self._cat_vectors.items():
+             logger.info(f"validating column {key}")
+             cat_vector.validate()
+             validated &= cat_vector.is_validated
+         self._is_validated = validated
+         self._non_validated = {}  # type: ignore
+
+         if self._index is not None:
+             # cat_vector.validate() populates validated labels;
+             # the index should become part of the feature set corresponding to the dataframe
+             if self._cat_vectors["columns"].records is not None:
+                 self._cat_vectors["columns"].records.insert(0, self._index)  # type: ignore
+             else:
+                 self._cat_vectors["columns"].records = [self._index]  # type: ignore
+
+         return self._is_validated
+
+     def standardize(self, key: str) -> None:
+         """Replace synonyms with standardized values.
+
+         Modifies the input dataset inplace.
+
+         Args:
+             key: The key referencing the column in the DataFrame to standardize.
+         """
+         if self._artifact is not None:
+             raise RuntimeError("can't mutate the dataset when an artifact is passed!")
+
+         if key == "all":
+             logger.warning(
+                 "'all' is deprecated, please pass a single key from `.non_validated.keys()` instead!"
+             )
+             for k in self.non_validated.keys():
+                 self._cat_vectors[k].standardize()
+         else:
+             self._cat_vectors[key].standardize()
+
+     def add_new_from(self, key: str, **kwargs):
+         """Add validated & new categories.
+
+         Args:
+             key: The key referencing the slot in the DataFrame from which to draw terms.
+             **kwargs: Additional keyword arguments to pass when creating new records.
+         """
+         if len(kwargs) > 0 and key == "all":
+             raise ValueError("Cannot pass additional arguments to the 'all' key!")
+         if key == "all":
+             logger.warning(
+                 "'all' is deprecated, please pass a single key from `.non_validated.keys()` instead!"
+             )
+             for k in self.non_validated.keys():
+                 self._cat_vectors[k].add_new(**kwargs)
+         else:
+             self._cat_vectors[key].add_new(**kwargs)
+
+
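
The typical fix-up loop after a failed validation; a sketch, assuming `curator` is a `DataFrameCurator`:

    import lamindb as ln

    try:
        curator.validate()
    except ln.errors.ValidationError:
        for key in curator.cat.non_validated:
            curator.cat.standardize(key)  # map synonyms onto standard terms
            curator.cat.add_new_from(key)  # register whatever remains as new records
        curator.validate()
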
+ def get_current_filter_kwargs(registry: type[Record], kwargs: dict) -> dict:
+     """Make sure the source and organism are saved in the same database as the registry."""
+     db = registry.filter().db
+     source = kwargs.get("source")
+     organism = kwargs.get("organism")
+     filter_kwargs = kwargs.copy()
+
+     if isinstance(organism, Record) and organism._state.db != "default":
+         if db is None or db == "default":
+             organism_default = copy.copy(organism)
+             # save the organism record in the default database
+             organism_default.save()
+             filter_kwargs["organism"] = organism_default
+     if isinstance(source, Record) and source._state.db != "default":
+         if db is None or db == "default":
+             source_default = copy.copy(source)
+             # save the source record in the default database
+             source_default.save()
+             filter_kwargs["source"] = source_default
+
+     return filter_kwargs
+
+
+ def get_organism_kwargs(
+     field: FieldAttr, organism: str | None = None, values: Any = None
+ ) -> dict[str, str]:
+     """Check if a registry needs an organism and return the organism name."""
+     registry = field.field.model
+     if registry.__base__.__name__ == "BioRecord":
+         import bionty as bt
+         from bionty._organism import is_organism_required
+
+         from ..models._from_values import get_organism_record_from_field
+
+         if is_organism_required(registry):
+             if organism is not None or bt.settings.organism is not None:
+                 return {"organism": organism or bt.settings.organism.name}
+             else:
+                 organism_record = get_organism_record_from_field(
+                     field, organism=organism, values=values
+                 )
+                 if organism_record is not None:
+                     return {"organism": organism_record.name}
+     return {}
+
+
+ def annotate_artifact(
+     artifact: Artifact,
+     *,
+     curator: AnnDataCurator | SlotsCurator | None = None,
+     cat_vectors: dict[str, CatVector] | None = None,
+ ) -> Artifact:
+     from .. import settings
+     from ..models.artifact import add_labels
+
+     if cat_vectors is None:
+         cat_vectors = {}
+
+     # annotate with labels
+     for key, cat_vector in cat_vectors.items():
+         if (
+             cat_vector._field.field.model == Feature
+             or key == "columns"
+             or key == "var_index"
+         ):
+             continue
+         if len(cat_vector.records) > settings.annotation.n_max_records:
+             logger.important(
+                 f"not annotating with {len(cat_vector.records)} labels for feature {key} as it exceeds {settings.annotation.n_max_records} (ln.settings.annotation.n_max_records)"
+             )
+             continue
+         add_labels(
+             artifact,
+             records=cat_vector.records,
+             feature=cat_vector.feature,
+             feature_ref_is_name=None,  # no longer needed
+             label_ref_is_name=cat_vector.label_ref_is_name,
+             from_curator=True,
+         )
+
+     # annotate with inferred schemas aka feature sets
+     if artifact.otype == "DataFrame":
+         features = cat_vectors["columns"].records
+         if features is not None:
+             feature_set = Schema(
+                 features=features, coerce_dtype=artifact.schema.coerce_dtype
+             )  # TODO: add more defaults from the validating schema
+             if (
+                 feature_set._state.adding
+                 and len(features) > settings.annotation.n_max_records
+             ):
+                 logger.important(
+                     f"not annotating with {len(features)} features as it exceeds {settings.annotation.n_max_records} (ln.settings.annotation.n_max_records)"
+                 )
+                 itype = parse_cat_dtype(artifact.schema.itype, is_itype=True)["field"]
+                 feature_set = Schema(itype=itype, n=len(features))
+             artifact.feature_sets.add(
+                 feature_set.save(), through_defaults={"slot": "columns"}
+             )
+     else:
+         for slot, slot_curator in curator._slots.items():
+             # var_index is backward compat (2025-05-01)
+             name = (
+                 "var_index"
+                 if (slot == "var" and "var_index" in slot_curator.cat._cat_vectors)
+                 else "columns"
+             )
+             features = slot_curator.cat._cat_vectors[name].records
+             itype = parse_cat_dtype(artifact.schema.slots[slot].itype, is_itype=True)[
+                 "field"
+             ]
+             feature_set = Schema(features=features, itype=itype)
+             if (
+                 feature_set._state.adding
+                 and len(features) > settings.annotation.n_max_records
+             ):
+                 logger.important(
+                     f"not annotating with {len(features)} features for slot {slot} as it exceeds {settings.annotation.n_max_records} (ln.settings.annotation.n_max_records)"
+                 )
+                 feature_set = Schema(itype=itype, n=len(features))
+             artifact.feature_sets.add(
+                 feature_set.save(), through_defaults={"slot": slot}
+             )
+
+     slug = ln_setup.settings.instance.slug
+     if ln_setup.settings.instance.is_remote:  # pragma: no cover
+         logger.important(f"go to https://lamin.ai/{slug}/artifact/{artifact.uid}")
+     return artifact
+
+
+ # TODO: need this function to support multi-value columns
+ def _flatten_unique(series: pd.Series[list[Any] | Any]) -> list[Any]:
+     """Flatten a Pandas series containing lists or single items into a unique list of elements."""
+     result = set()
+
+     for item in series:
+         if isinstance(item, list):
+             result.update(item)
+         else:
+             result.add(item)
+
+     return list(result)
+
+
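
For example:

    import pandas as pd

    series = pd.Series([["a", "b"], "c", "a"])
    sorted(_flatten_unique(series))  # -> ['a', 'b', 'c']; set semantics drop duplicates
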
+ def _save_organism(name: str):
+     """Save an organism record."""
+     import bionty as bt
+
+     organism = bt.Organism.filter(name=name).one_or_none()
+     if organism is None:
+         organism = bt.Organism.from_source(name=name)
+         if organism is None:
+             raise ValidationError(
+                 f'Organism "{name}" not found from public reference\n'
+                 f'   → please save it from a different source: bt.Organism.from_source(name="{name}", source).save()\n'
+                 f'   → or manually save it without source: bt.Organism(name="{name}").save()'
+             )
+         organism.save()
+     return organism