lamindb 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lamindb/__init__.py +52 -36
- lamindb/_finish.py +17 -10
- lamindb/_tracked.py +1 -1
- lamindb/base/__init__.py +3 -1
- lamindb/base/fields.py +40 -22
- lamindb/base/ids.py +1 -94
- lamindb/base/types.py +2 -0
- lamindb/base/uids.py +117 -0
- lamindb/core/_context.py +203 -102
- lamindb/core/_settings.py +38 -25
- lamindb/core/datasets/__init__.py +11 -4
- lamindb/core/datasets/_core.py +5 -5
- lamindb/core/datasets/_small.py +0 -93
- lamindb/core/datasets/mini_immuno.py +172 -0
- lamindb/core/loaders.py +1 -1
- lamindb/core/storage/_backed_access.py +100 -6
- lamindb/core/storage/_polars_lazy_df.py +51 -0
- lamindb/core/storage/_pyarrow_dataset.py +15 -30
- lamindb/core/storage/_tiledbsoma.py +29 -13
- lamindb/core/storage/objects.py +6 -0
- lamindb/core/subsettings/__init__.py +2 -0
- lamindb/core/subsettings/_annotation_settings.py +11 -0
- lamindb/curators/__init__.py +7 -3349
- lamindb/curators/_legacy.py +2056 -0
- lamindb/curators/core.py +1534 -0
- lamindb/errors.py +11 -0
- lamindb/examples/__init__.py +27 -0
- lamindb/examples/schemas/__init__.py +12 -0
- lamindb/examples/schemas/_anndata.py +25 -0
- lamindb/examples/schemas/_simple.py +19 -0
- lamindb/integrations/_vitessce.py +8 -5
- lamindb/migrations/0091_alter_featurevalue_options_alter_space_options_and_more.py +24 -0
- lamindb/migrations/0092_alter_artifactfeaturevalue_artifact_and_more.py +75 -0
- lamindb/migrations/0093_alter_schemacomponent_unique_together.py +16 -0
- lamindb/models/__init__.py +4 -1
- lamindb/models/_describe.py +21 -4
- lamindb/models/_feature_manager.py +382 -287
- lamindb/models/_label_manager.py +8 -2
- lamindb/models/artifact.py +177 -106
- lamindb/models/artifact_set.py +122 -0
- lamindb/models/collection.py +73 -52
- lamindb/models/core.py +1 -1
- lamindb/models/feature.py +51 -17
- lamindb/models/has_parents.py +69 -14
- lamindb/models/project.py +1 -1
- lamindb/models/query_manager.py +221 -22
- lamindb/models/query_set.py +247 -172
- lamindb/models/record.py +65 -247
- lamindb/models/run.py +4 -4
- lamindb/models/save.py +8 -2
- lamindb/models/schema.py +456 -184
- lamindb/models/transform.py +2 -2
- lamindb/models/ulabel.py +8 -5
- {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/METADATA +6 -6
- {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/RECORD +57 -43
- {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/LICENSE +0 -0
- {lamindb-1.4.0.dist-info → lamindb-1.5.1.dist-info}/WHEEL +0 -0
lamindb/curators/core.py
ADDED
@@ -0,0 +1,1534 @@
"""Curator utilities.

.. autosummary::
   :toctree: .

   Curator
   SlotsCurator
   CatVector
   CatLookup
   DataFrameCatManager

"""

from __future__ import annotations

import copy
import re
from typing import TYPE_CHECKING, Any, Callable

import lamindb_setup as ln_setup
import numpy as np
import pandas as pd
import pandera
from lamin_utils import colors, logger
from lamindb_setup.core._docs import doc_args

from lamindb.base.types import FieldAttr  # noqa
from lamindb.models import (
    Artifact,
    Feature,
    Record,
    Run,
    Schema,
)
from lamindb.models._from_values import _format_values
from lamindb.models.artifact import (
    data_is_anndata,
    data_is_mudata,
    data_is_spatialdata,
)
from lamindb.models.feature import parse_cat_dtype, parse_dtype

from ..errors import InvalidArgument, ValidationError

if TYPE_CHECKING:
    from collections.abc import Iterable
    from typing import Any

    from anndata import AnnData
    from mudata import MuData
    from spatialdata import SpatialData

    from lamindb.models.query_set import RecordList


def strip_ansi_codes(text):
    # This pattern matches ANSI escape sequences
    ansi_pattern = re.compile(r"\x1b\[[0-9;]*m")
    return ansi_pattern.sub("", text)
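
Warning messages are colorized for the console but stored as plain text on the cat manager, so the regex only needs to cover SGR color sequences; for example:

    >>> strip_ansi_codes("\x1b[31m2 terms\x1b[0m not validated")
    '2 terms not validated'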


class CatLookup:
    """Lookup categories from the reference instance.

    Args:
        categoricals: A dictionary of categorical fields to lookup.
        slots: A dictionary of slot fields to lookup.
        public: Whether to lookup from the public instance. Defaults to False.

    Example::

        curator = ln.curators.DataFrameCurator(...)
        curator.cat.lookup()["cell_type"].alveolar_type_1_fibroblast_cell

    """

    def __init__(
        self,
        categoricals: list[Feature] | dict[str, FieldAttr],
        slots: dict[str, FieldAttr] = None,
        public: bool = False,
        sources: dict[str, Record] | None = None,
    ) -> None:
        slots = slots or {}
        if isinstance(categoricals, list):
            categoricals = {
                feature.name: parse_dtype(feature.dtype)[0]["field"]
                for feature in categoricals
            }
        self._categoricals = {**categoricals, **slots}
        self._public = public
        self._sources = sources

    def __getattr__(self, name):
        if name in self._categoricals:
            registry = self._categoricals[name].field.model
            if self._public and hasattr(registry, "public"):
                return registry.public(source=self._sources.get(name)).lookup()
            else:
                return registry.lookup()
        raise AttributeError(
            f'"{self.__class__.__name__}" object has no attribute "{name}"'
        )

    def __getitem__(self, name):
        if name in self._categoricals:
            registry = self._categoricals[name].field.model
            if self._public and hasattr(registry, "public"):
                return registry.public(source=self._sources.get(name)).lookup()
            else:
                return registry.lookup()
        raise AttributeError(
            f'"{self.__class__.__name__}" object has no attribute "{name}"'
        )

    def __repr__(self) -> str:
        if len(self._categoricals) > 0:
            getattr_keys = "\n ".join(
                [f".{key}" for key in self._categoricals if key.isidentifier()]
            )
            getitem_keys = "\n ".join(
                [str([key]) for key in self._categoricals if not key.isidentifier()]
            )
            ref = "public" if self._public else "registries"
            return (
                f"Lookup objects from the {colors.italic(ref)}:\n "
                f"{colors.green(getattr_keys)}\n "
                f"{colors.green(getitem_keys)}\n"
                'Example:\n → categories = curator.lookup()["cell_type"]\n'
                " → categories.alveolar_type_1_fibroblast_cell\n\n"
                "To look up public ontologies, use .lookup(public=True)"
            )
        else:  # pragma: no cover
            return colors.warning("No fields are found!")
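
A minimal usage sketch (the attribute `t_cell` is a placeholder; available names depend on the categories in your registries):

    lookup = curator.cat.lookup()             # categories from the current instance
    lookup["cell_type"].t_cell                # bracket access works for any key
    public = curator.cat.lookup(public=True)  # categories from public ontologies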


CAT_MANAGER_DOCSTRING = """Manage categoricals by updating registries."""


SLOTS_DOCSTRING = """Access sub curators by slot."""


VALIDATE_DOCSTRING = """Validate dataset against Schema.

Raises:
    lamindb.errors.ValidationError: If validation fails.
"""

SAVE_ARTIFACT_DOCSTRING = """Save an annotated artifact.

Args:
    key: A path-like key to reference artifact in default storage, e.g., `"myfolder/myfile.fcs"`. Artifacts with the same key form a version family.
    description: A description.
    revises: Previous version of the artifact. An alternative to passing `key` to trigger a new version.
    run: The run that creates the artifact.

Returns:
    A saved artifact record.
"""


class Curator:
    """Curator base class.

    A `Curator` object makes it easy to validate, standardize & annotate datasets.

    See:
        - :class:`~lamindb.curators.DataFrameCurator`
        - :class:`~lamindb.curators.AnnDataCurator`
        - :class:`~lamindb.curators.MuDataCurator`
        - :class:`~lamindb.curators.SpatialDataCurator`
    """

    def __init__(self, dataset: Any, schema: Schema | None = None):
        self._artifact: Artifact = None  # pass the dataset as an artifact
        self._dataset: Any = dataset  # pass the dataset as a UPathStr or data object
        if isinstance(self._dataset, Artifact):
            self._artifact = self._dataset
            if self._artifact.otype in {
                "DataFrame",
                "AnnData",
                "MuData",
                "SpatialData",
            }:
                self._dataset = self._dataset.load(is_run_input=False)
        self._schema: Schema | None = schema
        self._is_validated: bool = False

    @doc_args(VALIDATE_DOCSTRING)
    def validate(self) -> bool | str:
        """{}"""  # noqa: D415
        pass  # pragma: no cover

    @doc_args(SAVE_ARTIFACT_DOCSTRING)
    def save_artifact(
        self,
        *,
        key: str | None = None,
        description: str | None = None,
        revises: Artifact | None = None,
        run: Run | None = None,
    ) -> Artifact:
        """{}"""  # noqa: D415
        # Note that this docstring has to be consistent with the Artifact()
        # constructor signature
        pass  # pragma: no cover

    def __repr__(self) -> str:
        from lamin_utils import colors

        if self._schema is not None:
            # Schema might have different attributes
            if hasattr(self._schema, "name") and self._schema.name:
                schema_str = colors.italic(self._schema.name)
            elif hasattr(self._schema, "uid"):
                schema_str = colors.italic(f"uid={self._schema.uid}")
            elif hasattr(self._schema, "id"):
                schema_str = colors.italic(f"id={self._schema.id}")
            else:
                schema_str = colors.italic("unnamed")

            # Add schema type info if available
            if hasattr(self._schema, "otype") and self._schema.otype:
                schema_str += f" ({self._schema.otype})"
        else:
            schema_str = colors.warning("None")

        status_str = ""
        if self._is_validated:
            status_str = f", {colors.green('validated')}"
        else:
            status_str = f", {colors.yellow('unvalidated')}"

        cls_name = colors.green(self.__class__.__name__)

        # Get additional info based on curator type
        extra_info = ""
        if hasattr(self, "_slots") and self._slots:
            # For SlotsCurator and its subclasses
            slots_count = len(self._slots)
            if slots_count > 0:
                slot_names = list(self._slots.keys())
                if len(slot_names) <= 3:
                    extra_info = f", slots: {slot_names}"
                else:
                    extra_info = f", slots: [{', '.join(slot_names[:3])}... +{len(slot_names) - 3} more]"
        elif (
            cls_name == "DataFrameCurator"
            and hasattr(self, "cat")
            and hasattr(self.cat, "_categoricals")
        ):
            # For DataFrameCurator
            cat_count = len(getattr(self.cat, "_categoricals", []))
            if cat_count > 0:
                extra_info = f", categorical_features={cat_count}"

        artifact_info = ""
        if self._artifact is not None:
            artifact_uid = getattr(self._artifact, "uid", str(self._artifact))
            short_uid = (
                str(artifact_uid)[:8] + "..."
                if len(str(artifact_uid)) > 8
                else str(artifact_uid)
            )
            artifact_info = f", artifact: {colors.italic(short_uid)}"

        return (
            f"{cls_name}{artifact_info}(Schema: {schema_str}{extra_info}{status_str})"
        )


# default implementation for AnnDataCurator, MuDataCurator, and SpatialDataCurator
class SlotsCurator(Curator):
    """Curator for a dataset with slots.

    Args:
        dataset: The dataset to validate & annotate.
        schema: A :class:`~lamindb.Schema` object that defines the validation constraints.

    """

    def __init__(
        self,
        dataset: Any,
        schema: Schema,
    ) -> None:
        super().__init__(dataset=dataset, schema=schema)
        self._slots: dict[str, DataFrameCurator] = {}

        # used in MuDataCurator and SpatialDataCurator
        # in form of {table/modality_key: var_field}
        self._var_fields: dict[str, FieldAttr] = {}
        # in form of {table/modality_key: categoricals}
        self._cat_vectors: dict[str, dict[str, CatVector]] = {}

    @property
    @doc_args(SLOTS_DOCSTRING)
    def slots(self) -> dict[str, DataFrameCurator]:
        """{}"""  # noqa: D415
        return self._slots

    @doc_args(VALIDATE_DOCSTRING)
    def validate(self) -> None:
        """{}"""  # noqa: D415
        for slot, curator in self._slots.items():
            logger.info(f"validating slot {slot} ...")
            curator.validate()

    @doc_args(SAVE_ARTIFACT_DOCSTRING)
    def save_artifact(
        self,
        *,
        key: str | None = None,
        description: str | None = None,
        revises: Artifact | None = None,
        run: Run | None = None,
    ) -> Artifact:
        """{}"""  # noqa: D415
        if not self._is_validated:
            self.validate()
        if self._artifact is None:
            if data_is_anndata(self._dataset):
                self._artifact = Artifact.from_anndata(
                    self._dataset,
                    key=key,
                    description=description,
                    revises=revises,
                    run=run,
                )
            if data_is_mudata(self._dataset):
                self._artifact = Artifact.from_mudata(
                    self._dataset,
                    key=key,
                    description=description,
                    revises=revises,
                    run=run,
                )
            elif data_is_spatialdata(self._dataset):
                self._artifact = Artifact.from_spatialdata(
                    self._dataset,
                    key=key,
                    description=description,
                    revises=revises,
                    run=run,
                )
        self._artifact.schema = self._schema
        self._artifact.save()
        cat_vectors = {}
        for curator in self._slots.values():
            for key, cat_vector in curator.cat._cat_vectors.items():
                cat_vectors[key] = cat_vector
        return annotate_artifact(  # type: ignore
            self._artifact,
            curator=self,
            cat_vectors=cat_vectors,
        )
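
Because sub-curators are plain `DataFrameCurator` objects, each slot can be inspected on its own; a sketch, assuming a composite `curator` built from one of the subclasses below:

    for slot, df_curator in curator.slots.items():
        df_curator.cat.validate()                  # categorical checks only, does not raise
        print(slot, df_curator.cat.non_validated)  # leftover terms per slot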


def check_dtype(expected_type) -> Callable:
    """Creates a check function for Pandera that validates a column's dtype.

    Args:
        expected_type: String identifier for the expected type ('int', 'float', or 'num')

    Returns:
        A function that checks if a series has the expected dtype
    """

    def check_function(series):
        if expected_type == "int":
            is_valid = pd.api.types.is_integer_dtype(series.dtype)
        elif expected_type == "float":
            is_valid = pd.api.types.is_float_dtype(series.dtype)
        elif expected_type == "num":
            is_valid = pd.api.types.is_numeric_dtype(series.dtype)
        return is_valid

    return check_function
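
With `element_wise=False`, pandera hands the whole column to the check, so the closure tests the dtype once per column rather than once per value; a minimal sketch of how the returned check plugs into a pandera column:

    import pandas as pd
    import pandera

    schema = pandera.DataFrameSchema(
        {
            "x": pandera.Column(
                dtype=None,
                checks=pandera.Check(check_dtype("num"), element_wise=False),
            )
        }
    )
    schema.validate(pd.DataFrame({"x": [1.0, 2.5]}))  # passes: float64 is numeric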


# this is also currently used as DictCurator
class DataFrameCurator(Curator):
    # the example in the docstring is tested in test_curators_quickstart_example
    """Curator for `DataFrame`.

    Args:
        dataset: The DataFrame-like object to validate & annotate.
        schema: A :class:`~lamindb.Schema` object that defines the validation constraints.
        slot: Indicate the slot in a composite curator for a composite data structure.

    Example:

        For a simple example using a flexible schema, see :meth:`~lamindb.Artifact.from_df`.

        Here is an example that enforces a minimal set of columns in the dataframe.

        .. literalinclude:: scripts/curate_dataframe_minimal_errors.py
            :language: python

        Under the hood, this uses the following schema.

        .. literalinclude:: scripts/define_mini_immuno_schema_flexible.py
            :language: python

        Valid features & labels were defined as:

        .. literalinclude:: scripts/define_mini_immuno_features_labels.py
            :language: python
    """

    def __init__(
        self,
        dataset: pd.DataFrame | Artifact,
        schema: Schema,
        slot: str | None = None,
    ) -> None:
        super().__init__(dataset=dataset, schema=schema)
        categoricals = []
        features = []
        feature_ids: set[int] = set()
        if schema.flexible:
            features += Feature.filter(name__in=self._dataset.keys()).list()
            feature_ids = {feature.id for feature in features}
        if schema.n > 0:
            if schema._index_feature_uid is not None:
                schema_features = [
                    feature
                    for feature in schema.members.list()
                    if feature.uid != schema._index_feature_uid  # type: ignore
                ]
            else:
                schema_features = schema.members.list()  # type: ignore
            if feature_ids:
                features.extend(
                    feature
                    for feature in schema_features
                    if feature.id not in feature_ids  # type: ignore
                )
            else:
                features.extend(schema_features)
        else:
            assert schema.itype is not None  # noqa: S101
        pandera_columns = {}
        if features or schema._index_feature_uid is not None:
            # populate features
            if schema.minimal_set:
                optional_feature_uids = set(schema.optionals.get_uids())
            for feature in features:
                if schema.minimal_set:
                    required = feature.uid not in optional_feature_uids
                else:
                    required = False
                if feature.dtype in {"int", "float", "num"}:
                    if isinstance(self._dataset, pd.DataFrame):
                        dtype = (
                            self._dataset[feature.name].dtype
                            if feature.name in self._dataset.keys()
                            else None
                        )
                    else:
                        dtype = None
                    pandera_columns[feature.name] = pandera.Column(
                        dtype=None,
                        checks=pandera.Check(
                            check_dtype(feature.dtype),
                            element_wise=False,
                            error=f"Column '{feature.name}' failed dtype check for '{feature.dtype}': got {dtype}",
                        ),
                        nullable=feature.nullable,
                        coerce=feature.coerce_dtype,
                        required=required,
                    )
                else:
                    pandera_dtype = (
                        feature.dtype
                        if not feature.dtype.startswith("cat")
                        else "category"
                    )
                    pandera_columns[feature.name] = pandera.Column(
                        pandera_dtype,
                        nullable=feature.nullable,
                        coerce=feature.coerce_dtype,
                        required=required,
                    )
                if feature.dtype.startswith("cat"):
                    # validate categoricals if the column is required or if the column is present
                    if required or feature.name in self._dataset.keys():
                        categoricals.append(feature)
            if schema._index_feature_uid is not None:
                # in almost no case, an index should have a pandas.CategoricalDtype in a DataFrame
                # so, we're typing it as `str` here
                index = pandera.Index(
                    schema.index.dtype
                    if not schema.index.dtype.startswith("cat")
                    else str
                )
            else:
                index = None
            self._pandera_schema = pandera.DataFrameSchema(
                pandera_columns,
                coerce=schema.coerce_dtype,
                strict=schema.maximal_set,
                ordered=schema.ordered_set,
                index=index,
            )
        self._cat_manager = DataFrameCatManager(
            self._dataset,
            columns_field=parse_cat_dtype(schema.itype, is_itype=True)["field"],
            columns_names=pandera_columns.keys(),
            categoricals=categoricals,
            index=schema.index,
            slot=slot,
            maximal_set=schema.maximal_set,
        )

    @property
    @doc_args(CAT_MANAGER_DOCSTRING)
    def cat(self) -> DataFrameCatManager:
        """{}"""  # noqa: D415
        return self._cat_manager

    def standardize(self) -> None:
        """Standardize the dataset.

        - Adds missing columns for features
        - Fills missing values for features with default values
        """
        for feature in self._schema.members:
            if feature.name not in self._dataset.columns:
                if feature.default_value is not None or feature.nullable:
                    fill_value = (
                        feature.default_value
                        if feature.default_value is not None
                        else pd.NA
                    )
                    if feature.dtype.startswith("cat"):
                        self._dataset[feature.name] = pd.Categorical(
                            [fill_value] * len(self._dataset)
                        )
                    else:
                        self._dataset[feature.name] = fill_value
                    logger.important(
                        f"added column {feature.name} with fill value {fill_value}"
                    )
                else:
                    raise ValidationError(
                        f"Missing column {feature.name} cannot be added because it is not nullable and has no default value"
                    )
            else:
                if feature.default_value is not None:
                    if isinstance(
                        self._dataset[feature.name].dtype, pd.CategoricalDtype
                    ):
                        if (
                            feature.default_value
                            not in self._dataset[feature.name].cat.categories
                        ):
                            self._dataset[feature.name] = self._dataset[
                                feature.name
                            ].cat.add_categories(feature.default_value)
                    self._dataset[feature.name] = self._dataset[feature.name].fillna(
                        feature.default_value
                    )

    def _cat_manager_validate(self) -> None:
        self.cat.validate()

        if self.cat._is_validated:
            self._is_validated = True
        else:
            self._is_validated = False
            raise ValidationError(self.cat._validate_category_error_messages)

    @doc_args(VALIDATE_DOCSTRING)
    def validate(self) -> None:
        """{}"""  # noqa: D415
        if self._schema.n > 0:
            try:
                # first validate through pandera
                self._pandera_schema.validate(self._dataset)
                # then validate lamindb categoricals
                self._cat_manager_validate()
            except pandera.errors.SchemaError as err:
                self._is_validated = False
                # .exconly() doesn't exist on SchemaError
                raise ValidationError(str(err)) from err
        else:
            self._cat_manager_validate()

    @doc_args(SAVE_ARTIFACT_DOCSTRING)
    def save_artifact(
        self,
        *,
        key: str | None = None,
        description: str | None = None,
        revises: Artifact | None = None,
        run: Run | None = None,
    ) -> Artifact:
        """{}"""  # noqa: D415
        if not self._is_validated:
            self.validate()  # raises ValidationError if doesn't validate
        if self._artifact is None:
            self._artifact = Artifact.from_df(
                self._dataset,
                key=key,
                description=description,
                revises=revises,
                run=run,
                # guard against key=None before checking the suffix
                format=".csv" if key is not None and key.endswith(".csv") else None,
            )
        self._artifact.schema = self._schema
        self._artifact.save()
        return annotate_artifact(  # type: ignore
            self._artifact,
            cat_vectors=self.cat._cat_vectors,
        )
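
A minimal end-to-end sketch of the intended workflow; the feature name, dtype, and key are placeholders, and `df` is any pandas DataFrame:

    import lamindb as ln

    schema = ln.Schema(
        features=[ln.Feature(name="perturbation", dtype="cat[ULabel]").save()],
    ).save()
    curator = ln.curators.DataFrameCurator(df, schema)
    try:
        curator.validate()
    except ln.errors.ValidationError:
        curator.cat.add_new_from("perturbation")  # register unseen labels
        curator.validate()
    artifact = curator.save_artifact(key="examples/mini.parquet")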


class AnnDataCurator(SlotsCurator):
    """Curator for `AnnData`.

    Args:
        dataset: The AnnData-like object to validate & annotate.
        schema: A :class:`~lamindb.Schema` object that defines the validation constraints.

    Example:

        See :meth:`~lamindb.Artifact.from_anndata`.

    """

    def __init__(
        self,
        dataset: AnnData | Artifact,
        schema: Schema,
    ) -> None:
        super().__init__(dataset=dataset, schema=schema)
        if not data_is_anndata(self._dataset):
            raise InvalidArgument("dataset must be AnnData-like.")
        if schema.otype != "AnnData":
            raise InvalidArgument("Schema otype must be 'AnnData'.")
        self._slots = {
            slot: DataFrameCurator(
                (
                    getattr(self._dataset, slot.strip(".T")).T
                    if slot == "var.T"
                    or (
                        # backward compat
                        slot == "var"
                        and schema.slots["var"].itype not in {None, "Feature"}
                    )
                    else getattr(self._dataset, slot)
                ),
                slot_schema,
                slot=slot,
            )
            for slot, slot_schema in schema.slots.items()
            if slot in {"obs", "var", "var.T", "uns"}
        }
        if "var" in self._slots and schema.slots["var"].itype not in {None, "Feature"}:
            logger.warning(
                "auto-transposed `var` for backward compat, please indicate transposition in the schema definition by calling out `.T`: slots={'var.T': itype=bt.Gene.ensembl_gene_id}"
            )
            self._slots["var"].cat._cat_vectors["var_index"] = self._slots[
                "var"
            ].cat._cat_vectors.pop("columns")
            self._slots["var"].cat._cat_vectors["var_index"]._key = "var_index"


def _assign_var_fields_categoricals_multimodal(
    modality: str | None,
    slot_type: str,
    slot: str,
    slot_schema: Schema,
    var_fields: dict[str, FieldAttr],
    cat_vectors: dict[str, dict[str, CatVector]],
    slots: dict[str, DataFrameCurator],
) -> None:
    """Assigns var_fields and categoricals for multimodal data curators."""
    if modality is not None:
        # Makes sure that all tables are present
        var_fields[modality] = None
        cat_vectors[modality] = {}

    if slot_type == "var":
        var_field = parse_cat_dtype(slot_schema.itype, is_itype=True)["field"]
        if modality is None:
            # This should rarely/never be used since tables should have different var fields
            var_fields[slot] = var_field  # pragma: no cover
        else:
            # Note that this is NOT nested since the nested key is always "var"
            var_fields[modality] = var_field
    else:
        obs_fields = slots[slot].cat._cat_vectors
        if modality is None:
            cat_vectors[slot] = obs_fields
        else:
            # Note that this is NOT nested since the nested key is always "obs"
            cat_vectors[modality] = obs_fields


class MuDataCurator(SlotsCurator):
    """Curator for `MuData`.

    Args:
        dataset: The MuData-like object to validate & annotate.
        schema: A :class:`~lamindb.Schema` object that defines the validation constraints.

    Example:

        .. literalinclude:: scripts/curate-mudata.py
            :language: python
            :caption: curate-mudata.py
    """

    def __init__(
        self,
        dataset: MuData | Artifact,
        schema: Schema,
    ) -> None:
        super().__init__(dataset=dataset, schema=schema)
        if not data_is_mudata(self._dataset):
            raise InvalidArgument("dataset must be MuData-like.")
        if schema.otype != "MuData":
            raise InvalidArgument("Schema otype must be 'MuData'.")

        for slot, slot_schema in schema.slots.items():
            if ":" in slot:
                modality, modality_slot = slot.split(":")
                schema_dataset = self._dataset.__getitem__(modality)
            else:
                modality, modality_slot = None, slot
                schema_dataset = self._dataset
            if modality_slot == "var" and schema.slots[slot].itype not in {
                None,
                "Feature",
            }:
                logger.warning(
                    "auto-transposed `var` for backward compat, please indicate transposition in the schema definition by calling out `.T`: slots={'var.T': itype=bt.Gene.ensembl_gene_id}"
                )
            self._slots[slot] = DataFrameCurator(
                (
                    getattr(schema_dataset, modality_slot.rstrip(".T")).T
                    if modality_slot == "var.T"
                    or (
                        # backward compat
                        modality_slot == "var"
                        and schema.slots[slot].itype not in {None, "Feature"}
                    )
                    else getattr(schema_dataset, modality_slot)
                ),
                slot_schema,
            )
            _assign_var_fields_categoricals_multimodal(
                modality=modality,
                slot_type=modality_slot,
                slot=slot,
                slot_schema=slot_schema,
                var_fields=self._var_fields,
                cat_vectors=self._cat_vectors,
                slots=self._slots,
            )
        self._columns_field = self._var_fields
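
Slot keys encode the modality before the colon, so a single schema can mix top-level and per-modality constraints; a sketch with placeholder schemas and a hypothetical `rna` modality:

    mudata_schema = ln.Schema(
        otype="MuData",
        slots={
            "obs": obs_schema,            # top-level sample metadata
            "rna:obs": rna_obs_schema,    # per-modality observation metadata
            "rna:var.T": rna_var_schema,  # per-modality variable identifiers
        },
    ).save()
    curator = ln.curators.MuDataCurator(mdata, mudata_schema)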


class SpatialDataCurator(SlotsCurator):
    """Curator for `SpatialData`.

    Args:
        dataset: The SpatialData-like object to validate & annotate.
        schema: A :class:`~lamindb.Schema` object that defines the validation constraints.

    Example:

        See :meth:`~lamindb.Artifact.from_spatialdata`.
    """

    def __init__(
        self,
        dataset: SpatialData | Artifact,
        schema: Schema,
        *,
        sample_metadata_key: str | None = "sample",
    ) -> None:
        super().__init__(dataset=dataset, schema=schema)
        if not data_is_spatialdata(self._dataset):
            raise InvalidArgument("dataset must be SpatialData-like.")
        if schema.otype != "SpatialData":
            raise InvalidArgument("Schema otype must be 'SpatialData'.")

        for slot, slot_schema in schema.slots.items():
            split_result = slot.split(":")
            if (len(split_result) == 2 and split_result[0] == "table") or (
                len(split_result) == 3 and split_result[0] == "tables"
            ):
                if len(split_result) == 2:
                    table_key, sub_slot = split_result
                    logger.warning(
                        f"please prefix slot {slot} with 'tables:' going forward"
                    )
                else:
                    table_key, sub_slot = split_result[1], split_result[2]
                slot_object = self._dataset.tables.__getitem__(table_key)
                if sub_slot == "var" and schema.slots[slot].itype not in {
                    None,
                    "Feature",
                }:
                    logger.warning(
                        "auto-transposed `var` for backward compat, please indicate transposition in the schema definition by calling out `.T`: slots={'var.T': itype=bt.Gene.ensembl_gene_id}"
                    )
                data_object = (
                    getattr(slot_object, sub_slot.rstrip(".T")).T
                    if sub_slot == "var.T"
                    or (
                        # backward compat
                        sub_slot == "var"
                        and schema.slots[slot].itype not in {None, "Feature"}
                    )
                    else getattr(slot_object, sub_slot)
                )
            elif len(split_result) == 1 or (
                len(split_result) > 1 and split_result[0] == "attrs"
            ):
                table_key = None
                if len(split_result) == 1:
                    if split_result[0] != "attrs":
                        logger.warning(
                            f"please prefix slot {slot} with 'attrs:' going forward"
                        )
                        sub_slot = slot
                        data_object = self._dataset.attrs[slot]
                    else:
                        sub_slot = "attrs"
                        data_object = self._dataset.attrs
                elif len(split_result) == 2:
                    sub_slot = split_result[1]
                    data_object = self._dataset.attrs[split_result[1]]
                data_object = pd.DataFrame([data_object])
            self._slots[slot] = DataFrameCurator(data_object, slot_schema, slot)
            _assign_var_fields_categoricals_multimodal(
                modality=table_key,
                slot_type=sub_slot,
                slot=slot,
                slot_schema=slot_schema,
                var_fields=self._var_fields,
                cat_vectors=self._cat_vectors,
                slots=self._slots,
            )
        self._columns_field = self._var_fields
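
Per the parsing above, table slots are addressed as `tables:<table_key>:<slot>` and sample-level metadata as `attrs:<key>`; a sketch with placeholder schemas and spatialdata's default table name `table`:

    spatialdata_schema = ln.Schema(
        otype="SpatialData",
        slots={
            "attrs:sample": sample_schema,   # sample-level metadata stored in .attrs
            "tables:table:obs": obs_schema,  # per-table observation metadata
        },
    ).save()
    curator = ln.curators.SpatialDataCurator(sdata, spatialdata_schema)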


class CatVector:
    """Vector with categorical values."""

    def __init__(
        self,
        values_getter: Callable
        | Iterable[str],  # A callable or iterable that returns the values to validate.
        field: FieldAttr,  # The field to validate against.
        key: str,  # The name of the vector to validate. Only used for logging.
        values_setter: Callable | None = None,  # A callable that sets the values.
        source: Record | None = None,  # The ontology source to validate against.
        feature: Feature | None = None,
        cat_manager: DataFrameCatManager | None = None,
        subtype_str: str = "",
        maximal_set: bool = True,  # whether unvalidated categoricals cause validation failure.
    ) -> None:
        self._values_getter = values_getter
        self._values_setter = values_setter
        self._field = field
        self._key = key
        self._source = source
        self._organism = None
        self._validated: None | list[str] = None
        self._non_validated: None | list[str] = None
        self._synonyms: None | dict[str, str] = None
        self._subtype_str = subtype_str
        self._subtype_query_set = None
        self._cat_manager = cat_manager
        self.feature = feature
        self.records = None
        self._maximal_set = maximal_set
        if hasattr(field.field.model, "_name_field"):
            label_ref_is_name = field.field.name == field.field.model._name_field
        else:
            label_ref_is_name = field.field.name == "name"
        self.label_ref_is_name = label_ref_is_name

    @property
    def values(self):
        """Get the current values using the getter function."""
        if callable(self._values_getter):
            return self._values_getter()
        return self._values_getter

    @values.setter
    def values(self, new_values):
        """Set new values using the setter function if available."""
        if callable(self._values_setter):
            self._values_setter(new_values)
        else:
            # If values_getter is not callable, it's a direct reference we can update
            self._values_getter = new_values

    @property
    def is_validated(self) -> bool:
        """Whether the vector is validated."""
        # if nothing was validated, something likely is fundamentally wrong
        # should probably add a setting `at_least_one_validated`
        result = True
        if len(self.values) > 0 and len(self.values) == len(self._non_validated):
            result = False
        # len(self._non_validated) != 0
        #   if maximal_set is True, return False
        #   if maximal_set is False, return True
        # len(self._non_validated) == 0
        #   return True
        if len(self._non_validated) != 0:
            if self._maximal_set:
                result = False
        return result

    def _replace_synonyms(self) -> list[str]:
        """Replace synonyms in the vector with standardized values."""
        syn_mapper = self._synonyms
        # replace the values in df
        std_values = self.values.map(
            lambda unstd_val: syn_mapper.get(unstd_val, unstd_val)
        )
        # remove the standardized values from self.non_validated
        non_validated = [i for i in self._non_validated if i not in syn_mapper]
        if len(non_validated) == 0:
            self._non_validated = []
        else:
            self._non_validated = non_validated  # type: ignore
        # logging
        n = len(syn_mapper)
        if n > 0:
            syn_mapper_print = _format_values(
                [f'"{k}" → "{v}"' for k, v in syn_mapper.items()], sep=""
            )
            s = "s" if n > 1 else ""
            logger.success(
                f'standardized {n} synonym{s} in "{self._key}": {colors.green(syn_mapper_print)}'
            )
        return std_values

    def __repr__(self) -> str:
        if self._non_validated is None:
            status = "unvalidated"
        else:
            status = (
                "validated"
                if len(self._non_validated) == 0
                else f"non-validated ({len(self._non_validated)})"
            )

        field_name = getattr(self._field, "name", str(self._field))
        values_count = len(self.values) if hasattr(self.values, "__len__") else "?"
        return f"CatVector(key='{self._key}', field='{field_name}', values={values_count}, {status})"

    def _add_validated(self) -> tuple[list, list]:
        """Save features or labels records in the default instance."""
        from lamindb.models.save import save as ln_save

        registry = self._field.field.model
        field_name = self._field.field.name
        model_field = registry.__get_name_with_module__()
        filter_kwargs = get_current_filter_kwargs(
            registry, {"organism": self._organism, "source": self._source}
        )
        values = [i for i in self.values if isinstance(i, str) and i]
        if not values:
            return [], []
        # inspect the default instance and save validated records from public
        if (
            self._subtype_str != "" and "__" not in self._subtype_str
        ):  # not for general filter expressions
            self._subtype_query_set = registry.get(name=self._subtype_str).records.all()
            values_array = np.array(values)
            validated_mask = self._subtype_query_set.validate(  # type: ignore
                values_array, field=self._field, **filter_kwargs, mute=True
            )
            validated_labels, non_validated_labels = (
                values_array[validated_mask],
                values_array[~validated_mask],
            )
            records = registry.from_values(
                validated_labels, field=self._field, **filter_kwargs, mute=True
            )
        else:
            existing_and_public_records = registry.from_values(
                list(values), field=self._field, **filter_kwargs, mute=True
            )
            existing_and_public_labels = [
                getattr(r, field_name) for r in existing_and_public_records
            ]
            # public records that are not already in the database
            public_records = [r for r in existing_and_public_records if r._state.adding]
            # here we check to only save the public records if they are from the specified source
            # we check the uid because r.source and source can be from different instances
            if self._source:
                public_records = [
                    r for r in public_records if r.source.uid == self._source.uid
                ]
            if len(public_records) > 0:
                logger.info(f"saving validated records of '{self._key}'")
                ln_save(public_records)
                labels_saved_public = [getattr(r, field_name) for r in public_records]
                # log the saved public labels
                # the term "transferred" stresses that this is always in the context of transferring
                # labels from a public ontology or a different instance to the present instance
                if len(labels_saved_public) > 0:
                    s = "s" if len(labels_saved_public) > 1 else ""
                    logger.success(
                        f'added {len(labels_saved_public)} record{s} {colors.green("from_public")} with {model_field} for "{self._key}": {_format_values(labels_saved_public)}'
                    )
            # non-validated records from the default instance
            non_validated_labels = [
                i for i in values if i not in existing_and_public_labels
            ]
            validated_labels = existing_and_public_labels
            records = existing_and_public_records

        self.records = records
        # validated, non-validated
        return validated_labels, non_validated_labels

    def _add_new(
        self,
        values: list[str],
        df: pd.DataFrame | None = None,  # remove when all users use schema
        dtype: str | None = None,
        **create_kwargs,
    ) -> None:
        """Add new labels to the registry."""
        from lamindb.models.save import save as ln_save

        registry = self._field.field.model
        field_name = self._field.field.name
        non_validated_records: RecordList[Any] = []  # type: ignore
        if df is not None and registry == Feature:
            nonval_columns = Feature.inspect(df.columns, mute=True).non_validated
            non_validated_records = Feature.from_df(df.loc[:, nonval_columns])
        else:
            if (
                self._organism
                and hasattr(registry, "organism")
                and registry._meta.get_field("organism").is_relation
            ):
                # make sure organism record is saved to the current instance
                create_kwargs["organism"] = _save_organism(name=self._organism)

            for value in values:
                init_kwargs = {field_name: value}
                if registry == Feature:
                    init_kwargs["dtype"] = "cat" if dtype is None else dtype
                non_validated_records.append(registry(**init_kwargs, **create_kwargs))
        if len(non_validated_records) > 0:
            ln_save(non_validated_records)
            model_field = colors.italic(registry.__get_name_with_module__())
            s = "s" if len(values) > 1 else ""
            logger.success(
                f'added {len(values)} record{s} with {model_field} for "{self._key}": {_format_values(values)}'
            )

    def _validate(
        self,
        values: list[str],
    ) -> tuple[list[str], dict]:
        """Validate ontology terms using LaminDB registries."""
        registry = self._field.field.model
        field_name = self._field.field.name
        model_field = f"{registry.__name__}.{field_name}"

        kwargs_current = get_current_filter_kwargs(
            registry, {"organism": self._organism, "source": self._source}
        )

        # inspect values from the default instance, excluding public
        registry_or_queryset = registry
        if self._subtype_query_set is not None:
            registry_or_queryset = self._subtype_query_set
        inspect_result = registry_or_queryset.inspect(
            values, field=self._field, mute=True, from_source=False, **kwargs_current
        )
        non_validated = inspect_result.non_validated
        syn_mapper = inspect_result.synonyms_mapper

        # inspect the non-validated values from public (BioRecord only)
        values_validated = []
        if hasattr(registry, "public"):
            public_records = registry.from_values(
                non_validated,
                field=self._field,
                mute=True,
                **kwargs_current,
            )
            values_validated += [getattr(r, field_name) for r in public_records]

        # logging messages
        if self._cat_manager is not None:
            slot = self._cat_manager._slot
        else:
            slot = None
        in_slot = f" in slot '{slot}'" if slot is not None else ""
        slot_prefix = f".slots['{slot}']" if slot is not None else ""
        non_validated_hint_print = (
            f"curator{slot_prefix}.cat.add_new_from('{self._key}')"
        )
        non_validated = [i for i in non_validated if i not in values_validated]
        n_non_validated = len(non_validated)
        if n_non_validated == 0:
            logger.success(
                f'"{self._key}" is validated against {colors.italic(model_field)}'
            )
            return [], {}
        else:
            s = "" if n_non_validated == 1 else "s"
            print_values = _format_values(non_validated)
            warning_message = f"{colors.red(f'{n_non_validated} term{s}')} not validated in feature '{self._key}'{in_slot}: {colors.red(print_values)}\n"
            if syn_mapper:
                s = "" if len(syn_mapper) == 1 else "s"
                syn_mapper_print = _format_values(
                    [f'"{k}" → "{v}"' for k, v in syn_mapper.items()], sep=""
                )
                hint_msg = f'.standardize("{self._key}")'
                warning_message += f" {colors.yellow(f'{len(syn_mapper)} synonym{s}')} found: {colors.yellow(syn_mapper_print)}\n → curate synonyms via: {colors.cyan(hint_msg)}"
            if n_non_validated > len(syn_mapper):
                if syn_mapper:
                    warning_message += "\n for remaining terms:\n"
                warning_message += f" → fix typos, remove non-existent values, or save terms via: {colors.cyan(non_validated_hint_print)}"
            if self._subtype_query_set is not None:
                warning_message += f"\n → a valid label for subtype '{self._subtype_str}' has to be one of {self._subtype_query_set.list('name')}"
            logger.info(f'mapping "{self._key}" on {colors.italic(model_field)}')
            logger.warning(warning_message)
            if self._cat_manager is not None:
                self._cat_manager._validate_category_error_messages = strip_ansi_codes(
                    warning_message
                )
            return non_validated, syn_mapper

    def validate(self) -> None:
        """Validate the vector."""
        # add source-validated values to the registry
        self._validated, self._non_validated = self._add_validated()
        self._non_validated, self._synonyms = self._validate(values=self._non_validated)

        # always register new Features if they are columns
        if self._key == "columns" and self._field == Feature.name:
            self.add_new()

    def standardize(self) -> None:
        """Standardize the vector."""
        registry = self._field.field.model
        if not hasattr(registry, "standardize"):
            return self.values
        if self._synonyms is None:
            self.validate()
        # get standardized values
        std_values = self._replace_synonyms()
        # update non_validated values
        self._non_validated = [
            i for i in self._non_validated if i not in self._synonyms.keys()
        ]
        # remove synonyms since they are now standardized
        self._synonyms = {}
        # update the values with the standardized values
        self.values = std_values

    def add_new(self, **create_kwargs) -> None:
        """Add new values to the registry."""
        if self._non_validated is None:
            self.validate()
        if len(self._synonyms) > 0:
            # raise error because .standardize modifies the input dataset
            raise ValidationError(
                "Please run `.standardize()` before adding new values."
            )
        self._add_new(
            values=self._non_validated,
            **create_kwargs,
        )
        # remove the non_validated values since they are now registered
        self._non_validated = []
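
`CatVector` can also be exercised on its own, which makes the validate → add_new cycle easy to see; a sketch against the `ULabel` registry with placeholder values:

    import lamindb as ln

    vector = CatVector(
        values_getter=["T cell", "B cell"],
        field=ln.ULabel.name,
        key="cell_type",
    )
    vector.validate()     # looks up existing & public records, flags the rest
    if not vector.is_validated:
        vector.add_new()  # registers the remaining terms as new ULabels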


class DataFrameCatManager:
    """Manage categoricals by updating registries.

    This class is accessible from within a `DataFrameCurator` via the `.cat` attribute.

    If you find non-validated values, you have two options:

    - new values found in the data can be registered via `DataFrameCurator.cat.add_new_from()` (:meth:`~lamindb.curators.core.DataFrameCatManager.add_new_from`)
    - non-validated values can be accessed via `DataFrameCurator.cat.non_validated` (:meth:`~lamindb.curators.core.DataFrameCatManager.non_validated`) and addressed manually
    """

    def __init__(
        self,
        df: pd.DataFrame | Artifact,
        columns_field: FieldAttr = Feature.name,
        columns_names: Iterable[str] | None = None,
        categoricals: list[Feature] | None = None,
        sources: dict[str, Record] | None = None,
        index: Feature | None = None,
        slot: str | None = None,
        maximal_set: bool = False,
    ) -> None:
        self._non_validated = None
        self._index = index
        self._artifact: Artifact = None  # pass the dataset as an artifact
        self._dataset: Any = df  # pass the dataset as a UPathStr or data object
        if isinstance(self._dataset, Artifact):
            self._artifact = self._dataset
            self._dataset = self._dataset.load(is_run_input=False)
        self._is_validated: bool = False
        self._categoricals = categoricals or []
        self._non_validated = None
        self._sources = sources or {}
        self._columns_field = columns_field
        self._validate_category_error_messages: str = ""
        self._cat_vectors: dict[str, CatVector] = {}
        self._slot = slot
        self._maximal_set = maximal_set

        if columns_names is None:
            columns_names = []
        if columns_field == Feature.name:
            self._cat_vectors["columns"] = CatVector(
                values_getter=columns_names,
                field=columns_field,
                key="columns" if isinstance(self._dataset, pd.DataFrame) else "keys",
                source=self._sources.get("columns"),
                cat_manager=self,
                maximal_set=self._maximal_set,
            )
        else:
            self._cat_vectors["columns"] = CatVector(
                values_getter=lambda: self._dataset.columns,  # lambda ensures the inplace update
                values_setter=lambda new_values: setattr(
                    self._dataset, "columns", pd.Index(new_values)
                ),
                field=columns_field,
                key="columns",
                source=self._sources.get("columns"),
                cat_manager=self,
                maximal_set=self._maximal_set,
            )
        for feature in self._categoricals:
            result = parse_dtype(feature.dtype)[
                0
            ]  # TODO: support composite dtypes for categoricals
            key = feature.name
            field = result["field"]
            subtype_str = result["subtype_str"]
            self._cat_vectors[key] = CatVector(
                values_getter=lambda k=key: self._dataset[
                    k
                ],  # Capture key as default argument
                values_setter=lambda new_values, k=key: self._dataset.__setitem__(
                    k, new_values
                ),
                field=field,
                key=key,
                source=self._sources.get(key),
                feature=feature,
                cat_manager=self,
                subtype_str=subtype_str,
            )
        if index is not None and index.dtype.startswith("cat"):
            result = parse_dtype(index.dtype)[0]
            field = result["field"]
            key = "index"
            self._cat_vectors[key] = CatVector(
                values_getter=self._dataset.index,
                field=field,
                key=key,
                feature=index,
                cat_manager=self,
            )

    @property
    def non_validated(self) -> dict[str, list[str]]:
        """Return the non-validated features and labels."""
        if self._non_validated is None:
            raise ValidationError("Please run validate() first!")
        return {
            key: cat_vector._non_validated
            for key, cat_vector in self._cat_vectors.items()
            if cat_vector._non_validated and key != "columns"
        }

    @property
    def categoricals(self) -> list[Feature]:
        """The categorical features."""
        return self._categoricals

    def lookup(self, public: bool = False) -> CatLookup:
        """Lookup categories.

        Args:
            public: If `True`, the lookup is performed on the public reference.
        """
        return CatLookup(
            categoricals=self._categoricals,
            slots={"columns": self._columns_field},
            public=public,
            sources=self._sources,
        )

    def validate(self) -> bool:
        """Validate variables and categorical observations."""
        self._validate_category_error_messages = ""  # reset the error messages

        validated = True
        for key, cat_vector in self._cat_vectors.items():
            logger.info(f"validating vector {key}")
            cat_vector.validate()
            validated &= cat_vector.is_validated
        self._is_validated = validated
        self._non_validated = {}  # type: ignore

        if self._index is not None:
            # cat_vector.validate() populates validated labels
            # the index should become part of the feature set corresponding to the dataframe
            if self._cat_vectors["columns"].records is not None:
                self._cat_vectors["columns"].records.insert(0, self._index)  # type: ignore
            else:
                self._cat_vectors["columns"].records = [self._index]  # type: ignore

        return self._is_validated

    def standardize(self, key: str) -> None:
        """Replace synonyms with standardized values.

        Modifies the input dataset inplace.

        Args:
            key: The key referencing the column in the DataFrame to standardize.
        """
        if self._artifact is not None:
            raise RuntimeError("can't mutate the dataset when an artifact is passed!")

        if key == "all":
            logger.warning(
                "'all' is deprecated, please pass a single key from `.non_validated.keys()` instead!"
            )
            for k in self.non_validated.keys():
                self._cat_vectors[k].standardize()
        else:
            self._cat_vectors[key].standardize()

    def add_new_from(self, key: str, **kwargs):
        """Add validated & new categories.

        Args:
            key: The key referencing the slot in the DataFrame from which to draw terms.
            **kwargs: Additional keyword arguments to pass to create new records.
        """
        if len(kwargs) > 0 and key == "all":
            raise ValueError("Cannot pass additional arguments to 'all' key!")
        if key == "all":
            logger.warning(
                "'all' is deprecated, please pass a single key from `.non_validated.keys()` instead!"
            )
            for k in self.non_validated.keys():
                self._cat_vectors[k].add_new(**kwargs)
        else:
            self._cat_vectors[key].add_new(**kwargs)
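
The typical fix-up loop after a failed validation, sketched for single columns (the column names are placeholders):

    curator.cat.validate()
    curator.cat.standardize("cell_type")   # replace synonyms inplace in the dataframe
    curator.cat.add_new_from("treatment")  # register remaining new terms as records
    curator.cat.non_validated              # {} once everything validates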


def get_current_filter_kwargs(registry: type[Record], kwargs: dict) -> dict:
    """Make sure the source and organism are saved in the same database as the registry."""
    db = registry.filter().db
    source = kwargs.get("source")
    organism = kwargs.get("organism")
    filter_kwargs = kwargs.copy()

    if isinstance(organism, Record) and organism._state.db != "default":
        if db is None or db == "default":
            organism_default = copy.copy(organism)
            # save the organism record in the default database
            organism_default.save()
            filter_kwargs["organism"] = organism_default
    if isinstance(source, Record) and source._state.db != "default":
        if db is None or db == "default":
            source_default = copy.copy(source)
            # save the source record in the default database
            source_default.save()
            filter_kwargs["source"] = source_default

    return filter_kwargs


def get_organism_kwargs(
    field: FieldAttr, organism: str | None = None, values: Any = None
) -> dict[str, str]:
    """Check if a registry needs an organism and return the organism name."""
    registry = field.field.model
    if registry.__base__.__name__ == "BioRecord":
        import bionty as bt
        from bionty._organism import is_organism_required

        from ..models._from_values import get_organism_record_from_field

        if is_organism_required(registry):
            if organism is not None or bt.settings.organism is not None:
                return {"organism": organism or bt.settings.organism.name}
            else:
                organism_record = get_organism_record_from_field(
                    field, organism=organism, values=values
                )
                if organism_record is not None:
                    return {"organism": organism_record.name}
    return {}


def annotate_artifact(
    artifact: Artifact,
    *,
    curator: AnnDataCurator | SlotsCurator | None = None,
    cat_vectors: dict[str, CatVector] | None = None,
) -> Artifact:
    from .. import settings
    from ..models.artifact import add_labels

    if cat_vectors is None:
        cat_vectors = {}

    # annotate with labels
    for key, cat_vector in cat_vectors.items():
        if (
            cat_vector._field.field.model == Feature
            or key == "columns"
            or key == "var_index"
        ):
            continue
        if len(cat_vector.records) > settings.annotation.n_max_records:
            logger.important(
                f"not annotating with {len(cat_vector.records)} labels for feature {key} as it exceeds {settings.annotation.n_max_records} (ln.settings.annotation.n_max_records)"
            )
            continue
        add_labels(
            artifact,
            records=cat_vector.records,
            feature=cat_vector.feature,
            feature_ref_is_name=None,  # do not need anymore
            label_ref_is_name=cat_vector.label_ref_is_name,
            from_curator=True,
        )

    # annotate with inferred schemas aka feature sets
    if artifact.otype == "DataFrame":
        features = cat_vectors["columns"].records
        if features is not None:
            feature_set = Schema(
                features=features, coerce_dtype=artifact.schema.coerce_dtype
            )  # TODO: add more defaults from validating schema
            if (
                feature_set._state.adding
                and len(features) > settings.annotation.n_max_records
            ):
                logger.important(
                    f"not annotating with {len(features)} features as it exceeds {settings.annotation.n_max_records} (ln.settings.annotation.n_max_records)"
                )
                itype = parse_cat_dtype(artifact.schema.itype, is_itype=True)["field"]
                feature_set = Schema(itype=itype, n=len(features))
            artifact.feature_sets.add(
                feature_set.save(), through_defaults={"slot": "columns"}
            )
    else:
        for slot, slot_curator in curator._slots.items():
            # var_index is backward compat (2025-05-01)
            name = (
                "var_index"
                if (slot == "var" and "var_index" in slot_curator.cat._cat_vectors)
                else "columns"
            )
            features = slot_curator.cat._cat_vectors[name].records
            if features is None:
                logger.warning(f"no features found for slot {slot}")
                continue
            itype = parse_cat_dtype(artifact.schema.slots[slot].itype, is_itype=True)[
                "field"
            ]
            feature_set = Schema(features=features, itype=itype)
            if (
                feature_set._state.adding
                and len(features) > settings.annotation.n_max_records
            ):
                logger.important(
                    f"not annotating with {len(features)} features for slot {slot} as it exceeds {settings.annotation.n_max_records} (ln.settings.annotation.n_max_records)"
                )
                feature_set = Schema(itype=itype, n=len(features))
            artifact.feature_sets.add(
                feature_set.save(), through_defaults={"slot": slot}
            )

    slug = ln_setup.settings.instance.slug
    if ln_setup.settings.instance.is_remote:  # pragma: no cover
        logger.important(f"go to https://lamin.ai/{slug}/artifact/{artifact.uid}")
    return artifact


# TODO: need this function to support multi-value columns
def _flatten_unique(series: pd.Series[list[Any] | Any]) -> list[Any]:
    """Flatten a Pandas series containing lists or single items into a unique list of elements."""
    result = set()

    for item in series:
        if isinstance(item, list):
            result.update(item)
        else:
            result.add(item)

    return list(result)
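
Behavior on a mixed column; the result order is unspecified because it comes from a set:

    >>> sorted(_flatten_unique(pd.Series([["a", "b"], "c", ["b"]])))
    ['a', 'b', 'c']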


def _save_organism(name: str):
    """Save an organism record."""
    import bionty as bt

    organism = bt.Organism.filter(name=name).one_or_none()
    if organism is None:
        organism = bt.Organism.from_source(name=name)
        if organism is None:
            raise ValidationError(
                f'Organism "{name}" not found from public reference\n'
                f' → please save it from a different source: bt.Organism.from_source(name="{name}", source).save()\n'
                f' → or manually save it without source: bt.Organism(name="{name}").save()'
            )
        organism.save()
    return organism