lamindb 1.11.2__py3-none-any.whl → 1.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lamindb/__init__.py +8 -14
- lamindb/_tracked.py +2 -0
- lamindb/base/types.py +1 -3
- lamindb/core/_context.py +16 -31
- lamindb/core/_mapped_collection.py +2 -2
- lamindb/core/storage/paths.py +5 -3
- lamindb/curators/core.py +15 -4
- lamindb/examples/__init__.py +3 -1
- lamindb/examples/croissant/__init__.py +3 -1
- lamindb/examples/mlflow/__init__.py +38 -0
- lamindb/examples/wandb/__init__.py +40 -0
- lamindb/integrations/__init__.py +26 -0
- lamindb/integrations/lightning.py +87 -0
- lamindb/migrations/0120_add_record_fk_constraint.py +1 -1
- lamindb/migrations/0122_remove_personproject_person_and_more.py +219 -0
- lamindb/migrations/0123_alter_artifact_description_alter_branch_description_and_more.py +82 -0
- lamindb/migrations/0124_page_artifact_page_collection_page_feature_page_and_more.py +15 -0
- lamindb/migrations/0125_artifact_is_locked_collection_is_locked_and_more.py +79 -0
- lamindb/migrations/0126_alter_artifact_is_locked_alter_collection_is_locked_and_more.py +105 -0
- lamindb/migrations/0127_alter_run_status_code_feature_dtype.py +31 -0
- lamindb/migrations/0128_artifact__real_key.py +21 -0
- lamindb/migrations/0129_remove_feature_page_remove_project_page_and_more.py +779 -0
- lamindb/migrations/0130_branch_space_alter_artifactblock_artifact_and_more.py +170 -0
- lamindb/migrations/0131_record_unique_name_type_space.py +18 -0
- lamindb/migrations/0132_record_parents_record_reference_and_more.py +61 -0
- lamindb/migrations/0133_artifactuser_artifact_users.py +108 -0
- lamindb/migrations/{0119_squashed.py → 0133_squashed.py} +1211 -322
- lamindb/models/__init__.py +14 -4
- lamindb/models/_django.py +1 -2
- lamindb/models/_feature_manager.py +1 -0
- lamindb/models/_is_versioned.py +14 -16
- lamindb/models/_relations.py +7 -0
- lamindb/models/artifact.py +99 -56
- lamindb/models/artifact_set.py +20 -3
- lamindb/models/block.py +174 -0
- lamindb/models/can_curate.py +7 -9
- lamindb/models/collection.py +9 -9
- lamindb/models/feature.py +38 -38
- lamindb/models/has_parents.py +15 -6
- lamindb/models/project.py +44 -99
- lamindb/models/query_manager.py +1 -1
- lamindb/models/query_set.py +36 -8
- lamindb/models/record.py +169 -46
- lamindb/models/run.py +44 -10
- lamindb/models/save.py +7 -7
- lamindb/models/schema.py +26 -7
- lamindb/models/sqlrecord.py +87 -35
- lamindb/models/storage.py +13 -3
- lamindb/models/transform.py +7 -2
- lamindb/models/ulabel.py +6 -23
- {lamindb-1.11.2.dist-info → lamindb-1.12.0.dist-info}/METADATA +18 -21
- {lamindb-1.11.2.dist-info → lamindb-1.12.0.dist-info}/RECORD +54 -38
- {lamindb-1.11.2.dist-info → lamindb-1.12.0.dist-info}/LICENSE +0 -0
- {lamindb-1.11.2.dist-info → lamindb-1.12.0.dist-info}/WHEEL +0 -0
lamindb/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
"""A data
|
1
|
+
"""A data lakehouse for biology.
|
2
2
|
|
3
3
|
Data lineage
|
4
4
|
============
|
@@ -31,21 +31,20 @@ Manage artifacts and transforms.
|
|
31
31
|
Transform
|
32
32
|
Run
|
33
33
|
|
34
|
-
|
34
|
+
Create labels and manage sheets with flexible records, e.g., for samples or donors.
|
35
35
|
|
36
36
|
.. autosummary::
|
37
37
|
:toctree: .
|
38
38
|
|
39
|
-
|
40
|
-
ULabel
|
41
|
-
Schema
|
39
|
+
Record
|
42
40
|
|
43
|
-
|
41
|
+
Define features & schemas to validate artifacts & records.
|
44
42
|
|
45
43
|
.. autosummary::
|
46
44
|
:toctree: .
|
47
45
|
|
48
|
-
|
46
|
+
Feature
|
47
|
+
Schema
|
49
48
|
|
50
49
|
Manage projects.
|
51
50
|
|
@@ -58,7 +57,6 @@ Manage projects.
|
|
58
57
|
Space
|
59
58
|
Branch
|
60
59
|
Reference
|
61
|
-
Person
|
62
60
|
|
63
61
|
Other
|
64
62
|
=====
|
@@ -106,15 +104,13 @@ Backwards compatibility.
|
|
106
104
|
.. autosummary::
|
107
105
|
:toctree: .
|
108
106
|
|
109
|
-
|
110
|
-
FeatureSet
|
111
|
-
Curator
|
107
|
+
ULabel
|
112
108
|
|
113
109
|
"""
|
114
110
|
|
115
111
|
# ruff: noqa: I001
|
116
112
|
# denote a release candidate for 0.1.0 with 0.1rc1, 0.1a1, 0.1b1, etc.
|
117
|
-
__version__ = "1.
|
113
|
+
__version__ = "1.12.0"
|
118
114
|
|
119
115
|
import warnings as _warnings
|
120
116
|
|
@@ -141,7 +137,6 @@ from .models import (
|
|
141
137
|
Collection,
|
142
138
|
Feature,
|
143
139
|
FeatureSet, # backward compat
|
144
|
-
Person,
|
145
140
|
Project,
|
146
141
|
Reference,
|
147
142
|
Run,
|
@@ -188,7 +183,6 @@ __all__ = [
|
|
188
183
|
"Space",
|
189
184
|
"Branch",
|
190
185
|
"Reference",
|
191
|
-
"Person",
|
192
186
|
# other
|
193
187
|
"connect",
|
194
188
|
"view",
|
lamindb/_tracked.py
CHANGED
@@ -91,6 +91,7 @@ def tracked(uid: str | None = None) -> Callable[[Callable[P, R]], Callable[P, R]
|
|
91
91
|
|
92
92
|
run = Run(transform=transform, initiated_by_run=initiated_by_run) # type: ignore
|
93
93
|
run.started_at = datetime.now(timezone.utc)
|
94
|
+
run._status_code = -1 # started
|
94
95
|
run.save()
|
95
96
|
|
96
97
|
# Bind arguments to get a mapping of parameter names to values
|
@@ -117,6 +118,7 @@ def tracked(uid: str | None = None) -> Callable[[Callable[P, R]], Callable[P, R]
|
|
117
118
|
try:
|
118
119
|
result = func(*args, **kwargs)
|
119
120
|
run.finished_at = datetime.now(timezone.utc)
|
121
|
+
run._status_code = 0 # completed
|
120
122
|
run.save()
|
121
123
|
return result
|
122
124
|
finally:
|
lamindb/base/types.py
CHANGED
@@ -34,9 +34,7 @@ from lamindb_setup.types import UPathStr # noqa: F401
|
|
34
34
|
ListLike = Union[list[str], pd.Series, np.array]
|
35
35
|
StrField = Union[str, FieldAttr] # typing.TypeAlias
|
36
36
|
|
37
|
-
TransformType = Literal[
|
38
|
-
"pipeline", "notebook", "upload", "script", "function", "linker"
|
39
|
-
]
|
37
|
+
TransformType = Literal["pipeline", "notebook", "script", "function", "linker"]
|
40
38
|
ArtifactKind = Literal["dataset", "model", "__lamindb_run__"]
|
41
39
|
|
42
40
|
# below is used for Feature.dtype and Param.dtype
|
lamindb/core/_context.py
CHANGED
@@ -135,6 +135,12 @@ class LogStreamHandler:
|
|
135
135
|
if not self.file.closed:
|
136
136
|
self.file.flush()
|
137
137
|
|
138
|
+
# https://laminlabs.slack.com/archives/C07DB677JF6/p1759423901926139
|
139
|
+
# other tracking frameworks like W&B use our output stream and expect
|
140
|
+
# certain functions like isatty to be available
|
141
|
+
def isatty(self) -> bool:
|
142
|
+
return False
|
143
|
+
|
138
144
|
# .flush is sometimes (in jupyter etc.) called after every .write
|
139
145
|
# this needs to be called only at the end
|
140
146
|
def flush_buffer(self):
|
@@ -441,10 +447,6 @@ class Context:
|
|
441
447
|
) = self._track_source_code(path=path)
|
442
448
|
if description is None:
|
443
449
|
description = self._description
|
444
|
-
# temporarily until the hub displays the key by default
|
445
|
-
# populate the description with the filename again
|
446
|
-
if description is None:
|
447
|
-
description = self._path.name
|
448
450
|
self._create_or_load_transform(
|
449
451
|
description=description,
|
450
452
|
transform_ref=transform_ref,
|
@@ -710,8 +712,14 @@ class Context:
|
|
710
712
|
aux_transform = Transform.filter(hash=transform_hash).one_or_none()
|
711
713
|
else:
|
712
714
|
aux_transform = None
|
715
|
+
|
716
|
+
# determine the transform key
|
717
|
+
if ln_setup.settings.work_dir is not None:
|
718
|
+
key = self._path.relative_to(ln_setup.settings.work_dir).as_posix()
|
719
|
+
else:
|
720
|
+
key = self._path.name
|
713
721
|
# if the user did not pass a uid and there is no matching aux_transform
|
714
|
-
# need to search for the transform based on the
|
722
|
+
# need to search for the transform based on the key
|
715
723
|
if self.uid is None and aux_transform is None:
|
716
724
|
|
717
725
|
class SlashCount(Func):
|
@@ -720,12 +728,11 @@ class Context:
|
|
720
728
|
|
721
729
|
# we need to traverse from greater depth to shorter depth so that we match better matches first
|
722
730
|
transforms = (
|
723
|
-
Transform.filter(key__endswith=
|
731
|
+
Transform.filter(key__endswith=key, is_latest=True)
|
724
732
|
.annotate(slash_count=SlashCount("key"))
|
725
733
|
.order_by("-slash_count")
|
726
734
|
)
|
727
735
|
uid = f"{base62_12()}0000"
|
728
|
-
key = self._path.name
|
729
736
|
target_transform = None
|
730
737
|
if len(transforms) != 0:
|
731
738
|
message = ""
|
@@ -755,19 +762,6 @@ class Context:
|
|
755
762
|
# the user did pass the uid
|
756
763
|
elif self.uid is not None and len(self.uid) == 16:
|
757
764
|
transform = Transform.filter(uid=self.uid).one_or_none()
|
758
|
-
if transform is not None:
|
759
|
-
if transform.key not in self._path.as_posix():
|
760
|
-
n_parts = len(Path(transform.key).parts)
|
761
|
-
(
|
762
|
-
Path(*self._path.parts[-n_parts:]).as_posix()
|
763
|
-
if n_parts > 0
|
764
|
-
else ""
|
765
|
-
)
|
766
|
-
key = self._path.name
|
767
|
-
else:
|
768
|
-
key = transform.key # type: ignore
|
769
|
-
else:
|
770
|
-
key = self._path.name
|
771
765
|
else:
|
772
766
|
if self.uid is not None:
|
773
767
|
# the case with length 16 is covered above
|
@@ -784,10 +778,8 @@ class Context:
|
|
784
778
|
# deal with a hash-based match
|
785
779
|
# the user might have a made a copy of the notebook or script
|
786
780
|
# and actually wants to create a new transform
|
787
|
-
if aux_transform is not None and not aux_transform.key.endswith(
|
788
|
-
|
789
|
-
):
|
790
|
-
prompt = f"Found transform with same hash but different key: {aux_transform.key}. Did you rename your {transform_type} to {self._path.name} (1) or intentionally made a copy (2)?"
|
781
|
+
if aux_transform is not None and not aux_transform.key.endswith(key):
|
782
|
+
prompt = f"Found transform with same hash but different key: {aux_transform.key}. Did you rename your {transform_type} to {key} (1) or intentionally made a copy (2)?"
|
791
783
|
response = (
|
792
784
|
"1" if os.getenv("LAMIN_TESTING") == "true" else input(prompt)
|
793
785
|
)
|
@@ -800,12 +792,6 @@ class Context:
|
|
800
792
|
None,
|
801
793
|
) # make a new transform
|
802
794
|
if aux_transform is not None:
|
803
|
-
if aux_transform.key.endswith(self._path.name):
|
804
|
-
key = aux_transform.key
|
805
|
-
else:
|
806
|
-
key = "/".join(
|
807
|
-
aux_transform.key.split("/")[:-1] + [self._path.name]
|
808
|
-
)
|
809
795
|
uid, target_transform, message = self._process_aux_transform(
|
810
796
|
aux_transform, transform_hash
|
811
797
|
)
|
@@ -814,7 +800,6 @@ class Context:
|
|
814
800
|
else:
|
815
801
|
uid = f"{self.uid}0000" if self.uid is not None else None
|
816
802
|
target_transform = None
|
817
|
-
key = self._path.name
|
818
803
|
self.uid, transform = uid, target_transform
|
819
804
|
if self.version is not None:
|
820
805
|
# test inconsistent version passed
|
@@ -634,8 +634,8 @@ class MappedCollection:
|
|
634
634
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
635
635
|
self.close()
|
636
636
|
|
637
|
-
@
|
638
|
-
def torch_worker_init_fn(worker_id):
|
637
|
+
@classmethod
|
638
|
+
def torch_worker_init_fn(cls, worker_id):
|
639
639
|
"""`worker_init_fn` for `torch.utils.data.DataLoader`.
|
640
640
|
|
641
641
|
Improves performance for `num_workers > 1`.
|
lamindb/core/storage/paths.py
CHANGED
@@ -25,12 +25,14 @@ AUTO_KEY_PREFIX = ".lamindb/"
|
|
25
25
|
|
26
26
|
# add type annotations back asap when re-organizing the module
|
27
27
|
def auto_storage_key_from_artifact(artifact: Artifact):
|
28
|
-
if artifact.
|
28
|
+
if (real_key := artifact._real_key) is not None:
|
29
|
+
return real_key
|
30
|
+
key = artifact.key
|
31
|
+
if key is None or artifact._key_is_virtual:
|
29
32
|
return auto_storage_key_from_artifact_uid(
|
30
33
|
artifact.uid, artifact.suffix, artifact.overwrite_versions
|
31
34
|
)
|
32
|
-
|
33
|
-
return artifact.key
|
35
|
+
return artifact.key
|
34
36
|
|
35
37
|
|
36
38
|
def auto_storage_key_from_artifact_uid(
|
lamindb/curators/core.py
CHANGED
@@ -625,7 +625,7 @@ class ComponentCurator(Curator):
|
|
625
625
|
has_dtype_error = "WRONG_DATATYPE" in str(err)
|
626
626
|
error_msg = str(err)
|
627
627
|
if has_dtype_error:
|
628
|
-
error_msg += " ▶ Hint: Consider setting '
|
628
|
+
error_msg += " ▶ Hint: Consider setting 'coerce_dtype=True' to attempt coercing/converting values during validation to the pre-defined dtype."
|
629
629
|
raise ValidationError(error_msg) from err
|
630
630
|
else:
|
631
631
|
self._cat_manager_validate()
|
@@ -911,7 +911,7 @@ class AnnDataCurator(SlotsCurator):
|
|
911
911
|
super().__init__(dataset=dataset, schema=schema)
|
912
912
|
if not data_is_scversedatastructure(self._dataset, "AnnData"):
|
913
913
|
raise InvalidArgument("dataset must be AnnData-like.")
|
914
|
-
if schema.otype
|
914
|
+
if schema.otype != "AnnData":
|
915
915
|
raise InvalidArgument("Schema otype must be 'AnnData'.")
|
916
916
|
|
917
917
|
for slot, slot_schema in schema.slots.items():
|
@@ -1388,7 +1388,7 @@ class CatVector:
|
|
1388
1388
|
related_name = registry._meta.get_field("type").remote_field.related_name
|
1389
1389
|
type_record = registry.get(name=self._subtype_str)
|
1390
1390
|
if registry.__name__ == "Record":
|
1391
|
-
self._subtype_query_set = type_record.
|
1391
|
+
self._subtype_query_set = type_record.query_records()
|
1392
1392
|
else:
|
1393
1393
|
self._subtype_query_set = getattr(type_record, related_name).all()
|
1394
1394
|
values_array = np.array(str_values)
|
@@ -1561,7 +1561,18 @@ class CatVector:
|
|
1561
1561
|
if n_non_validated > len(syn_mapper):
|
1562
1562
|
if syn_mapper:
|
1563
1563
|
warning_message += "\n for remaining terms:\n"
|
1564
|
-
|
1564
|
+
check_organism = ""
|
1565
|
+
if registry.__base__.__name__ == "BioRecord":
|
1566
|
+
import bionty as bt
|
1567
|
+
from bionty._organism import is_organism_required
|
1568
|
+
|
1569
|
+
if is_organism_required(registry):
|
1570
|
+
organism = (
|
1571
|
+
valid_inspect_kwargs.get("organism", False)
|
1572
|
+
or bt.settings.organism.name
|
1573
|
+
)
|
1574
|
+
check_organism = f"fix organism '{organism}', "
|
1575
|
+
warning_message += f" → {check_organism}fix typos, remove non-existent values, or save terms via: {colors.cyan(non_validated_hint_print)}"
|
1565
1576
|
if self._subtype_query_set is not None:
|
1566
1577
|
warning_message += f"\n → a valid label for subtype '{self._subtype_str}' has to be one of {self._subtype_query_set.to_list('name')}"
|
1567
1578
|
logger.info(f'mapping "{self._key}" on {colors.italic(model_field)}')
|
lamindb/examples/__init__.py
CHANGED
@@ -17,7 +17,7 @@ def mini_immuno(
|
|
17
17
|
"""Return paths to the mini immuno dataset and its metadata as a Croissant file.
|
18
18
|
|
19
19
|
Args:
|
20
|
-
n_files: Number of files inside the croissant file.
|
20
|
+
n_files: Number of files inside the croissant file.
|
21
21
|
filepath_prefix: Move the dataset and references to it in a specific directory.
|
22
22
|
|
23
23
|
Example
|
@@ -63,8 +63,10 @@ def mini_immuno(
|
|
63
63
|
croissant_path = Path("mini_immuno.anndata.zarr_metadata.json")
|
64
64
|
with open(croissant_path, "w", encoding="utf-8") as f:
|
65
65
|
json.dump(data, f, indent=2)
|
66
|
+
|
66
67
|
result: list[Path] = [croissant_path, dataset1_path]
|
67
68
|
if n_files == 1:
|
68
69
|
return result
|
69
70
|
result.append(dataset2_path)
|
71
|
+
|
70
72
|
return result
|
@@ -0,0 +1,38 @@
|
|
1
|
+
"""Examples and utilities for Mlflow.
|
2
|
+
|
3
|
+
.. autosummary::
|
4
|
+
:toctree: .
|
5
|
+
|
6
|
+
save_mlflow_features
|
7
|
+
"""
|
8
|
+
|
9
|
+
import lamindb as ln
|
10
|
+
|
11
|
+
|
12
|
+
def save_mlflow_features():
|
13
|
+
"""Saves all MLflow experiment and run related features.
|
14
|
+
|
15
|
+
Saves the following features:
|
16
|
+
|
17
|
+
- mlflow_run_id
|
18
|
+
- mlflow_run_name
|
19
|
+
- mlflow_experiment_id
|
20
|
+
- mlflow_experiment_name
|
21
|
+
- mlflow_user_id
|
22
|
+
- mlflow_status
|
23
|
+
- mlflow_lifecycle_stage
|
24
|
+
- mlflow_artifact_uri
|
25
|
+
- mlflow_start_time
|
26
|
+
- mlflow_end_time
|
27
|
+
"""
|
28
|
+
mlflow_type = ln.Feature(name="MLflow", is_type=True).save()
|
29
|
+
ln.Feature(name="mlflow_run_id", dtype=str, type=mlflow_type).save()
|
30
|
+
ln.Feature(name="mlflow_run_name", dtype=str, type=mlflow_type).save()
|
31
|
+
ln.Feature(name="mlflow_experiment_id", dtype=str, type=mlflow_type).save()
|
32
|
+
ln.Feature(name="mlflow_experiment_name", dtype=str, type=mlflow_type).save()
|
33
|
+
ln.Feature(name="mlflow_user_id", dtype=str, type=mlflow_type).save()
|
34
|
+
ln.Feature(name="mlflow_status", dtype=str, type=mlflow_type).save()
|
35
|
+
ln.Feature(name="mlflow_lifecycle_stage", dtype=str, type=mlflow_type).save()
|
36
|
+
ln.Feature(name="mlflow_artifact_uri", dtype=str, type=mlflow_type).save()
|
37
|
+
ln.Feature(name="mlflow_start_time", dtype=int, type=mlflow_type).save()
|
38
|
+
ln.Feature(name="mlflow_end_time", dtype=int, type=mlflow_type).save()
|
@@ -0,0 +1,40 @@
|
|
1
|
+
"""Examples and utilities for Weights & Biases.
|
2
|
+
|
3
|
+
.. autosummary::
|
4
|
+
:toctree: .
|
5
|
+
|
6
|
+
save_wandb_features
|
7
|
+
"""
|
8
|
+
|
9
|
+
import lamindb as ln
|
10
|
+
|
11
|
+
|
12
|
+
def save_wandb_features():
|
13
|
+
"""Saves all Weights & Biases project and run related features.
|
14
|
+
|
15
|
+
Saves the following features:
|
16
|
+
|
17
|
+
- wandb_run_id
|
18
|
+
- wandb_run_name
|
19
|
+
- wandb_run_entity
|
20
|
+
- wandb_project
|
21
|
+
- wandb_state
|
22
|
+
- wandb_url
|
23
|
+
- wandb_tags
|
24
|
+
- wandb_group
|
25
|
+
- wandb_job_type
|
26
|
+
- timestamp
|
27
|
+
- runtime
|
28
|
+
"""
|
29
|
+
wandb_type = ln.Feature(name="Weights & Biases", is_type=True).save()
|
30
|
+
ln.Feature(name="wandb_run_id", dtype=str, type=wandb_type).save()
|
31
|
+
ln.Feature(name="wandb_run_name", dtype=str, type=wandb_type).save()
|
32
|
+
ln.Feature(name="wandb_run_entity", dtype=str, type=wandb_type).save()
|
33
|
+
ln.Feature(name="wandb_project", dtype=str, type=wandb_type).save()
|
34
|
+
ln.Feature(name="wandb_state", dtype=str, type=wandb_type).save()
|
35
|
+
ln.Feature(name="wandb_url", dtype=str, type=wandb_type).save()
|
36
|
+
ln.Feature(name="wandb_tags", dtype=str, type=wandb_type).save()
|
37
|
+
ln.Feature(name="wandb_group", dtype=str, type=wandb_type).save()
|
38
|
+
ln.Feature(name="wandb_job_type", dtype=str, type=wandb_type).save()
|
39
|
+
ln.Feature(name="wandb_timestamp", dtype=float, type=wandb_type).save()
|
40
|
+
ln.Feature(name="wandb_runtime", dtype=float, type=wandb_type).save()
|
lamindb/integrations/__init__.py
CHANGED
@@ -6,9 +6,35 @@
|
|
6
6
|
save_vitessce_config
|
7
7
|
save_tiledbsoma_experiment
|
8
8
|
curate_from_croissant
|
9
|
+
lightning
|
9
10
|
"""
|
10
11
|
|
12
|
+
from typing import Any
|
13
|
+
|
14
|
+
|
15
|
+
def __getattr__(attr_name: str) -> Any:
|
16
|
+
# Defers import until accessed to avoid requiring PyTorch Lightning
|
17
|
+
if attr_name == "lightning":
|
18
|
+
from lamindb.integrations import lightning
|
19
|
+
|
20
|
+
return lightning
|
21
|
+
raise AttributeError(f"module has no attribute {attr_name!r}")
|
22
|
+
|
23
|
+
|
11
24
|
from lamindb.core.storage import save_tiledbsoma_experiment
|
12
25
|
|
13
26
|
from ._croissant import curate_from_croissant
|
14
27
|
from ._vitessce import save_vitessce_config
|
28
|
+
|
29
|
+
|
30
|
+
def __dir__():
|
31
|
+
# Makes lazy imports discoverable to dir() to enable autocomplete including lazy modules
|
32
|
+
return __all__
|
33
|
+
|
34
|
+
|
35
|
+
__all__ = [
|
36
|
+
"lightning",
|
37
|
+
"save_tiledbsoma_experiment",
|
38
|
+
"curate_from_croissant",
|
39
|
+
"save_vitessce_config",
|
40
|
+
]
|
@@ -0,0 +1,87 @@
|
|
1
|
+
"""PyTorch Lightning integrations.
|
2
|
+
|
3
|
+
.. autosummary::
|
4
|
+
:toctree: .
|
5
|
+
|
6
|
+
Callback
|
7
|
+
"""
|
8
|
+
|
9
|
+
from pathlib import Path
|
10
|
+
from typing import Any
|
11
|
+
|
12
|
+
import lightning as pl
|
13
|
+
from lightning.pytorch import LightningModule, Trainer
|
14
|
+
|
15
|
+
import lamindb as ln
|
16
|
+
|
17
|
+
|
18
|
+
class Callback(pl.Callback):
|
19
|
+
"""Saves PyTorch Lightning model checkpoints to the LaminDB instance after each training epoch.
|
20
|
+
|
21
|
+
Creates version families of artifacts for given `key` (relative file path).
|
22
|
+
|
23
|
+
Args:
|
24
|
+
path: Path to the checkpoint
|
25
|
+
key: Artifact key
|
26
|
+
features: Additional feature values that every checkpoint gets annotated by.
|
27
|
+
|
28
|
+
Examples:
|
29
|
+
|
30
|
+
Create a callback which creates artifacts for checkpoints and annotates them by the MLflow run ID
|
31
|
+
|
32
|
+
lamindb_callback = ln.integrations.lightning.Callback(
|
33
|
+
path=checkpoint_filename, key=artifact_key, annotate_by={ "mlflow_run_id": mlflow_run.info.run_id }
|
34
|
+
)
|
35
|
+
trainer = pl.Trainer(
|
36
|
+
callbacks=[lamindb_callback]
|
37
|
+
)
|
38
|
+
"""
|
39
|
+
|
40
|
+
def __init__(
|
41
|
+
self,
|
42
|
+
path: str | Path,
|
43
|
+
key: str,
|
44
|
+
features: dict[str, Any] | None = None,
|
45
|
+
):
|
46
|
+
self.path = Path(path)
|
47
|
+
self.key = key
|
48
|
+
self.features = features or {}
|
49
|
+
|
50
|
+
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
|
51
|
+
"""Validates that features exist for all specified params."""
|
52
|
+
missing = [
|
53
|
+
feature
|
54
|
+
for feature in self.features.keys()
|
55
|
+
if ln.Feature.filter(name=feature).one_or_none() is None
|
56
|
+
]
|
57
|
+
if missing:
|
58
|
+
s = "s" if len(missing) > 1 else ""
|
59
|
+
raise ValueError(
|
60
|
+
f"Feature{s} {', '.join(missing)} missing. Create {'them' if len(missing) > 1 else 'it'} first."
|
61
|
+
)
|
62
|
+
|
63
|
+
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
|
64
|
+
"""Saves model checkpoint artifacts at the end of each epoch and optionally annotates them."""
|
65
|
+
trainer.save_checkpoint(self.path)
|
66
|
+
af = ln.Artifact(self.path, key=self.key, kind="model").save()
|
67
|
+
|
68
|
+
feature_values = dict(self.features)
|
69
|
+
|
70
|
+
for name in self.features.keys():
|
71
|
+
if hasattr(trainer, name):
|
72
|
+
feature_values[name] = getattr(trainer, name)
|
73
|
+
elif name in trainer.callback_metrics:
|
74
|
+
metric_value = trainer.callback_metrics[name]
|
75
|
+
feature_values[name] = (
|
76
|
+
metric_value.item()
|
77
|
+
if hasattr(metric_value, "item")
|
78
|
+
else float(metric_value)
|
79
|
+
)
|
80
|
+
|
81
|
+
if feature_values:
|
82
|
+
af.features.add_values(feature_values)
|
83
|
+
|
84
|
+
af.save()
|
85
|
+
|
86
|
+
|
87
|
+
__all__ = ["Callback"]
|