atdata-0.3.0b1-py3-none-any.whl → atdata-0.3.2b1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. atdata/__init__.py +11 -0
  2. atdata/_cid.py +0 -21
  3. atdata/_helpers.py +12 -0
  4. atdata/_hf_api.py +46 -1
  5. atdata/_logging.py +43 -0
  6. atdata/_protocols.py +81 -182
  7. atdata/_schema_codec.py +2 -2
  8. atdata/_sources.py +24 -4
  9. atdata/_stub_manager.py +5 -25
  10. atdata/atmosphere/__init__.py +60 -21
  11. atdata/atmosphere/_lexicon_types.py +595 -0
  12. atdata/atmosphere/_types.py +73 -245
  13. atdata/atmosphere/client.py +64 -12
  14. atdata/atmosphere/lens.py +60 -53
  15. atdata/atmosphere/records.py +291 -100
  16. atdata/atmosphere/schema.py +91 -65
  17. atdata/atmosphere/store.py +68 -66
  18. atdata/cli/__init__.py +16 -16
  19. atdata/cli/diagnose.py +2 -2
  20. atdata/cli/{local.py → infra.py} +10 -10
  21. atdata/dataset.py +266 -47
  22. atdata/index/__init__.py +54 -0
  23. atdata/{local → index}/_entry.py +6 -2
  24. atdata/{local → index}/_index.py +617 -72
  25. atdata/{local → index}/_schema.py +5 -5
  26. atdata/lexicons/__init__.py +127 -0
  27. atdata/lexicons/ac.foundation.dataset.arrayFormat.json +16 -0
  28. atdata/lexicons/ac.foundation.dataset.getLatestSchema.json +78 -0
  29. atdata/lexicons/ac.foundation.dataset.lens.json +101 -0
  30. atdata/lexicons/ac.foundation.dataset.record.json +117 -0
  31. atdata/lexicons/ac.foundation.dataset.schema.json +107 -0
  32. atdata/lexicons/ac.foundation.dataset.schemaType.json +16 -0
  33. atdata/lexicons/ac.foundation.dataset.storageBlobs.json +46 -0
  34. atdata/lexicons/ac.foundation.dataset.storageExternal.json +25 -0
  35. atdata/lexicons/ac.foundation.dataset.storageHttp.json +45 -0
  36. atdata/lexicons/ac.foundation.dataset.storageS3.json +61 -0
  37. atdata/lexicons/ndarray_shim.json +16 -0
  38. atdata/local/__init__.py +12 -13
  39. atdata/local/_repo_legacy.py +3 -3
  40. atdata/manifest/__init__.py +4 -0
  41. atdata/manifest/_proxy.py +321 -0
  42. atdata/promote.py +14 -10
  43. atdata/repository.py +66 -16
  44. atdata/stores/__init__.py +23 -0
  45. atdata/stores/_disk.py +131 -0
  46. atdata/{local → stores}/_s3.py +134 -112
  47. atdata/testing.py +12 -8
  48. {atdata-0.3.0b1.dist-info → atdata-0.3.2b1.dist-info}/METADATA +2 -2
  49. atdata-0.3.2b1.dist-info/RECORD +71 -0
  50. atdata-0.3.0b1.dist-info/RECORD +0 -54
  51. {atdata-0.3.0b1.dist-info → atdata-0.3.2b1.dist-info}/WHEEL +0 -0
  52. {atdata-0.3.0b1.dist-info → atdata-0.3.2b1.dist-info}/entry_points.txt +0 -0
  53. {atdata-0.3.0b1.dist-info → atdata-0.3.2b1.dist-info}/licenses/LICENSE +0 -0
atdata/dataset.py CHANGED
@@ -47,8 +47,6 @@ from ._protocols import DataSource, Packable
  from ._exceptions import SampleKeyError, PartialFailureError

  import numpy as np
- import pandas as pd
- import requests

  import typing
  from typing import (
@@ -70,6 +68,9 @@ from typing import (
  )

  if TYPE_CHECKING:
+ import pandas
+ import pandas as pd
+ from .manifest._proxy import Predicate
  from .manifest._query import SampleLocation
  from numpy.typing import NDArray

@@ -99,9 +100,11 @@ DT = TypeVar("DT")


  def _make_packable(x):
- """Convert numpy arrays to bytes; pass through other values unchanged."""
+ """Convert numpy arrays to bytes; coerce numpy scalars to Python natives."""
  if isinstance(x, np.ndarray):
  return eh.array_to_bytes(x)
+ if isinstance(x, np.generic):
+ return x.item()
  return x


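The new ``np.generic`` branch matters because msgpack cannot serialize numpy scalar types directly. A minimal standalone sketch of the behavior it papers over (not code from the package):

    import msgpack
    import numpy as np

    value = np.float32(0.5)                # np.float32 is not a Python float subclass
    # msgpack.packb(value) raises TypeError without a custom `default` hook
    packed = msgpack.packb(value.item())   # .item() converts to a plain Python float
    assert msgpack.unpackb(packed) == 0.5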
@@ -280,16 +283,9 @@ class PackableSample(ABC):

  @property
  def packed(self) -> bytes:
- """Serialize to msgpack bytes. NDArray fields are auto-converted.
-
- Raises:
- RuntimeError: If msgpack serialization fails.
- """
+ """Serialize to msgpack bytes. NDArray fields are auto-converted."""
  o = {k: _make_packable(v) for k, v in vars(self).items()}
- ret = msgpack.packb(o)
- if ret is None:
- raise RuntimeError(f"Failed to pack sample to bytes: {o}")
- return ret
+ return msgpack.packb(o)

  @property
  def as_wds(self) -> WDSRawSample:
@@ -305,7 +301,7 @@ def _batch_aggregate(xs: Sequence):
  if not xs:
  return []
  if isinstance(xs[0], np.ndarray):
- return np.array(list(xs))
+ return np.stack(xs)
  return list(xs)


@@ -540,6 +536,8 @@ class Dataset(Generic[ST]):
  return None

  if self._metadata is None:
+ import requests
+
  with requests.get(self.metadata_url, stream=True) as response:
  response.raise_for_status()
  self._metadata = msgpack.unpackb(response.content, raw=False)
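Moving ``import requests`` into the method defers the dependency until metadata is actually fetched, so importing atdata no longer pays for it. The same idiom in isolation (function name and URL handling are illustrative):

    def fetch_msgpack_metadata(url: str) -> bytes:
        import requests  # deferred: only loaded when metadata is requested

        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            return response.content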
@@ -708,6 +706,8 @@ class Dataset(Generic[ST]):
  fn: Callable[[list[ST]], Any],
  *,
  shards: list[str] | None = None,
+ checkpoint: Path | str | None = None,
+ on_shard_error: Callable[[str, Exception], None] | None = None,
  ) -> dict[str, Any]:
  """Process each shard independently, collecting per-shard results.

@@ -723,6 +723,14 @@ class Dataset(Generic[ST]):
  shards: Optional list of shard identifiers to process. If ``None``,
  processes all shards in the dataset. Useful for retrying only
  the failed shards from a previous ``PartialFailureError``.
+ checkpoint: Optional path to a checkpoint file. If provided,
+ already-succeeded shard IDs are loaded from this file and
+ skipped. Each newly succeeded shard is appended. On full
+ success the file is deleted. On partial failure it remains
+ for resume.
+ on_shard_error: Optional callback invoked as
+ ``on_shard_error(shard_id, exception)`` for each failed shard,
+ enabling dead-letter logging or alerting.

  Returns:
  Dict mapping shard identifier to *fn*'s return value for each shard.
@@ -739,45 +747,67 @@
  ... results = ds.process_shards(expensive_fn)
  ... except PartialFailureError as e:
  ... retry = ds.process_shards(expensive_fn, shards=e.failed_shards)
+
+ >>> # With checkpoint for crash recovery:
+ >>> results = ds.process_shards(expensive_fn, checkpoint="progress.txt")
  """
- from ._logging import get_logger
+ from ._logging import get_logger, log_operation

  log = get_logger()
  shard_ids = shards or self.list_shards()
- log.info("process_shards: starting %d shards", len(shard_ids))
+
+ # Load checkpoint: skip already-succeeded shards
+ checkpoint_path: Path | None = None
+ if checkpoint is not None:
+ checkpoint_path = Path(checkpoint)
+ if checkpoint_path.exists():
+ already_done = set(checkpoint_path.read_text().splitlines())
+ log.info(
+ "process_shards: loaded checkpoint, %d shards already done",
+ len(already_done),
+ )
+ shard_ids = [s for s in shard_ids if s not in already_done]
+ if not shard_ids:
+ log.info("process_shards: all shards already checkpointed")
+ return {}

  succeeded: list[str] = []
  failed: list[str] = []
  errors: dict[str, Exception] = {}
  results: dict[str, Any] = {}

- for shard_id in shard_ids:
- try:
- shard_ds = Dataset[self.sample_type](shard_id)
- shard_ds._sample_type_cache = self._sample_type_cache
- samples = list(shard_ds.ordered())
- results[shard_id] = fn(samples)
- succeeded.append(shard_id)
- log.debug("process_shards: shard ok %s", shard_id)
- except Exception as exc:
- failed.append(shard_id)
- errors[shard_id] = exc
- log.warning("process_shards: shard failed %s: %s", shard_id, exc)
-
- if failed:
- log.error(
- "process_shards: %d/%d shards failed",
- len(failed),
- len(shard_ids),
- )
- raise PartialFailureError(
- succeeded_shards=succeeded,
- failed_shards=failed,
- errors=errors,
- results=results,
- )
+ with log_operation("process_shards", total_shards=len(shard_ids)):
+ for shard_id in shard_ids:
+ try:
+ shard_ds = Dataset[self.sample_type](shard_id)
+ shard_ds._sample_type_cache = self._sample_type_cache
+ samples = list(shard_ds.ordered())
+ results[shard_id] = fn(samples)
+ succeeded.append(shard_id)
+ log.debug("process_shards: shard ok %s", shard_id)
+ if checkpoint_path is not None:
+ with open(checkpoint_path, "a") as f:
+ f.write(shard_id + "\n")
+ except Exception as exc:
+ failed.append(shard_id)
+ errors[shard_id] = exc
+ log.warning("process_shards: shard failed %s: %s", shard_id, exc)
+ if on_shard_error is not None:
+ on_shard_error(shard_id, exc)
+
+ if failed:
+ raise PartialFailureError(
+ succeeded_shards=succeeded,
+ failed_shards=failed,
+ errors=errors,
+ results=results,
+ )
+
+ # All shards succeeded; clean up checkpoint file
+ if checkpoint_path is not None and checkpoint_path.exists():
+ checkpoint_path.unlink()
+ log.debug("process_shards: checkpoint file removed (all shards done)")

- log.info("process_shards: all %d shards succeeded", len(shard_ids))
  return results

  def select(self, indices: Sequence[int]) -> list[ST]:
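A usage sketch combining the new ``checkpoint`` and ``on_shard_error`` parameters (the dataset, ``expensive_fn``, and the import path for ``PartialFailureError`` are illustrative; the public export may differ):

    import logging
    from atdata._exceptions import PartialFailureError

    def expensive_fn(samples) -> int:
        return len(samples)

    def dead_letter(shard_id: str, exc: Exception) -> None:
        logging.getLogger("dead-letter").error("shard %s failed: %r", shard_id, exc)

    try:
        results = ds.process_shards(      # ds: an atdata.Dataset built elsewhere
            expensive_fn,
            checkpoint="progress.txt",    # succeeded shard IDs persist here between runs
            on_shard_error=dead_letter,   # invoked once per failed shard
        )
    except PartialFailureError as e:
        # A rerun with the same checkpoint skips shards that already succeeded;
        # e.failed_shards can also be passed explicitly via shards=.
        results = ds.process_shards(
            expensive_fn, shards=e.failed_shards, checkpoint="progress.txt"
        )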
@@ -809,9 +839,25 @@ class Dataset(Generic[ST]):
  break
  return [result[i] for i in indices if i in result]

+ @property
+ def fields(self) -> "Any":
+ """Typed field proxy for manifest queries on this dataset.
+
+ Returns an object whose attributes are ``FieldProxy`` instances,
+ one per manifest-eligible field of this dataset's sample type.
+
+ Examples:
+ >>> ds = atdata.Dataset[MySample](url)
+ >>> Q = ds.fields
+ >>> results = ds.query(where=(Q.confidence > 0.9))
+ """
+ from .manifest._proxy import query_fields
+
+ return query_fields(self.sample_type)
+
  def query(
  self,
- where: "Callable[[pd.DataFrame], pd.Series]",
+ where: "Callable[[pd.DataFrame], pd.Series] | Predicate",
  ) -> "list[SampleLocation]":
  """Query this dataset using per-shard manifest metadata.

@@ -820,10 +866,12 @@ class Dataset(Generic[ST]):
  and executes a two-phase query (shard-level aggregate pruning,
  then sample-level parquet filtering).

+ The *where* argument accepts either a lambda/function that operates
+ on a pandas DataFrame, or a ``Predicate`` built from the proxy DSL.
+
  Args:
- where: Predicate function that receives a pandas DataFrame
- of manifest fields and returns a boolean Series selecting
- matching rows.
+ where: Predicate function or ``Predicate`` object that selects
+ matching rows from the per-sample manifest DataFrame.

  Returns:
  List of ``SampleLocation`` for matching samples.
@@ -835,6 +883,9 @@
  >>> locs = ds.query(where=lambda df: df["confidence"] > 0.9)
  >>> len(locs)
  42
+
+ >>> Q = ds.fields
+ >>> locs = ds.query(where=(Q.confidence > 0.9))
  """
  from .manifest import QueryExecutor

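The two ``where=`` forms documented above are interchangeable; a brief sketch (the ``confidence`` field is illustrative):

    Q = ds.fields                                                # FieldProxy container
    locs_dsl = ds.query(where=(Q.confidence > 0.9))              # Predicate from the proxy DSL
    locs_fn = ds.query(where=lambda df: df["confidence"] > 0.9)  # pandas-callable form
    # Both return a list of SampleLocation objects for the matching samples.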
@@ -842,7 +893,7 @@ class Dataset(Generic[ST]):
  executor = QueryExecutor.from_shard_urls(shard_urls)
  return executor.query(where=where)

- def to_pandas(self, limit: int | None = None) -> "pd.DataFrame":
+ def to_pandas(self, limit: int | None = None) -> "pandas.DataFrame":
  """Materialize the dataset (or first *limit* samples) as a DataFrame.

  Args:
@@ -865,6 +916,8 @@ class Dataset(Generic[ST]):
  rows = [
  asdict(s) if dataclasses.is_dataclass(s) else s.to_dict() for s in samples
  ]
+ import pandas as pd
+
  return pd.DataFrame(rows)

  def to_dict(self, limit: int | None = None) -> dict[str, list[Any]]:
@@ -1059,6 +1112,8 @@ class Dataset(Generic[ST]):
  Examples:
  >>> ds.to_parquet("output.parquet", maxcount=50000)
  """
+ import pandas as pd
+
  path = Path(path)
  if sample_map is None:
  sample_map = asdict
@@ -1127,7 +1182,7 @@ _T = TypeVar("_T")


  @dataclass_transform()
- def packable(cls: type[_T]) -> type[Packable]:
+ def packable(cls: type[_T]) -> type[_T]:
  """Convert a class into a ``PackableSample`` dataclass with msgpack serialization.

  The resulting class gains ``packed``, ``as_wds``, ``from_bytes``, and
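Returning ``type[_T]`` instead of ``type[Packable]`` lets static checkers keep seeing the decorated class's own fields. A sketch (class and field names are illustrative):

    import numpy as np
    from numpy.typing import NDArray
    from atdata.dataset import packable

    @packable
    class ImageSample:
        key: str
        image: NDArray

    s = ImageSample(key="0", image=np.zeros((4, 4)))
    print(type(s.packed))   # <class 'bytes'>; checkers also still see s.key and s.image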
@@ -1188,3 +1243,167 @@ def packable(cls: type[_T]) -> type[Packable]:
  ##

  return as_packable
+
+
+ # ---------------------------------------------------------------------------
+ # write_samples — convenience function for writing samples to tar files
+ # ---------------------------------------------------------------------------
+
+
+ def write_samples(
+ samples: Iterable[ST],
+ path: str | Path,
+ *,
+ maxcount: int | None = None,
+ maxsize: int | None = None,
+ manifest: bool = False,
+ ) -> "Dataset[ST]":
+ """Write an iterable of samples to WebDataset tar file(s).
+
+ Args:
+ samples: Iterable of ``PackableSample`` instances. Must be non-empty.
+ path: Output path for the tar file. For sharded output (when
+ *maxcount* or *maxsize* is set), a ``%06d`` pattern is
+ auto-appended if the path does not already contain ``%``.
+ maxcount: Maximum samples per shard. Triggers multi-shard output.
+ maxsize: Maximum bytes per shard. Triggers multi-shard output.
+ manifest: If True, write per-shard manifest sidecar files
+ (``.manifest.json`` + ``.manifest.parquet``) alongside each
+ tar file. Manifests enable metadata queries via
+ ``QueryExecutor`` without opening the tars.
+
+ Returns:
+ A ``Dataset`` wrapping the written file(s), typed to the sample
+ type of the input samples.
+
+ Raises:
+ ValueError: If *samples* is empty.
+
+ Examples:
+ >>> samples = [MySample(key="0", text="hello")]
+ >>> ds = write_samples(samples, "out.tar")
+ >>> list(ds.ordered())
+ [MySample(key='0', text='hello')]
+ """
+ from ._hf_api import _shards_to_wds_url
+ from ._logging import get_logger, log_operation
+
+ if manifest:
+ from .manifest._builder import ManifestBuilder
+ from .manifest._writer import ManifestWriter
+
+ log = get_logger()
+ path = Path(path)
+ path.parent.mkdir(parents=True, exist_ok=True)
+
+ use_shard_writer = maxcount is not None or maxsize is not None
+ sample_type: type | None = None
+ written_paths: list[str] = []
+
+ with log_operation(
+ "write_samples", path=str(path), sharded=use_shard_writer, manifest=manifest
+ ):
+ # Manifest tracking state
+ _current_builder: list = [] # single-element list for nonlocal mutation
+ _builders: list[tuple[str, "ManifestBuilder"]] = []
+ _running_offset: list[int] = [0]
+
+ def _finalize_builder() -> None:
+ """Finalize the current manifest builder and stash it."""
+ if _current_builder:
+ shard_path = written_paths[-1] if written_paths else ""
+ _builders.append((shard_path, _current_builder[0]))
+ _current_builder.clear()
+
+ def _start_builder(shard_path: str) -> None:
+ """Start a new manifest builder for a shard."""
+ _finalize_builder()
+ shard_id = Path(shard_path).stem
+ _current_builder.append(
+ ManifestBuilder(sample_type=sample_type, shard_id=shard_id)
+ )
+ _running_offset[0] = 0
+
+ def _record_sample(sample: "PackableSample", wds_dict: dict) -> None:
+ """Record a sample in the active manifest builder."""
+ if not _current_builder:
+ return
+ packed_bytes = wds_dict["msgpack"]
+ size = len(packed_bytes)
+ _current_builder[0].add_sample(
+ key=wds_dict["__key__"],
+ offset=_running_offset[0],
+ size=size,
+ sample=sample,
+ )
+ _running_offset[0] += size
+
+ if use_shard_writer:
+ # Build shard pattern from path
+ if "%" not in str(path):
+ pattern = str(path.parent / f"{path.stem}-%06d{path.suffix}")
+ else:
+ pattern = str(path)
+
+ writer_kwargs: dict[str, Any] = {}
+ if maxcount is not None:
+ writer_kwargs["maxcount"] = maxcount
+ if maxsize is not None:
+ writer_kwargs["maxsize"] = maxsize
+
+ def _track(p: str) -> None:
+ written_paths.append(str(Path(p).resolve()))
+ if manifest and sample_type is not None:
+ _start_builder(p)
+
+ with wds.writer.ShardWriter(pattern, post=_track, **writer_kwargs) as sink:
+ for sample in samples:
+ if sample_type is None:
+ sample_type = type(sample)
+ wds_dict = sample.as_wds
+ sink.write(wds_dict)
+ if manifest:
+ # The first sample triggers _track before we get here when
+ # ShardWriter opens the first shard, but just in case:
+ if not _current_builder and sample_type is not None:
+ _start_builder(str(path))
+ _record_sample(sample, wds_dict)
+ else:
+ with wds.writer.TarWriter(str(path)) as sink:
+ for sample in samples:
+ if sample_type is None:
+ sample_type = type(sample)
+ wds_dict = sample.as_wds
+ sink.write(wds_dict)
+ if manifest:
+ if not _current_builder and sample_type is not None:
+ _current_builder.append(
+ ManifestBuilder(
+ sample_type=sample_type, shard_id=path.stem
+ )
+ )
+ _record_sample(sample, wds_dict)
+ written_paths.append(str(path.resolve()))
+
+ if sample_type is None:
+ raise ValueError("samples must be non-empty")
+
+ # Finalize and write manifests
+ if manifest:
+ _finalize_builder()
+ for shard_path, builder in _builders:
+ m = builder.build()
+ base = str(Path(shard_path).with_suffix(""))
+ writer = ManifestWriter(base)
+ writer.write(m)
+
+ log.info(
+ "write_samples: wrote %d shard(s), sample_type=%s",
+ len(written_paths),
+ sample_type.__name__,
+ )
+
+ url = _shards_to_wds_url(written_paths)
+ ds: Dataset = Dataset(url)
+ ds._sample_type_cache = sample_type
+ return ds
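A usage sketch for the new ``write_samples`` helper with sharding and manifests (``MySample``, its fields, and the paths are illustrative):

    from atdata.dataset import packable, write_samples

    @packable
    class MySample:
        key: str
        text: str

    samples = [MySample(key=str(i), text=f"sample {i}") for i in range(10_000)]
    ds = write_samples(
        samples,
        "out/data.tar",   # becomes out/data-%06d.tar because maxcount triggers sharding
        maxcount=1_000,   # at most 1000 samples per shard
        manifest=True,    # also writes .manifest.json / .manifest.parquet sidecars
    )
    print(len(list(ds.ordered())))   # 10000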
atdata/index/__init__.py ADDED
@@ -0,0 +1,54 @@
+ """Index and entry models for atdata datasets.
+
+ Key classes:
+
+ - ``Index``: Unified index with pluggable providers (SQLite default),
+ named repositories, and optional atmosphere backend.
+ - ``LocalDatasetEntry``: Index entry with ATProto-compatible CIDs.
+ """
+
+ from atdata.index._entry import (
+ LocalDatasetEntry,
+ BasicIndexEntry,
+ REDIS_KEY_DATASET_ENTRY,
+ REDIS_KEY_SCHEMA,
+ )
+ from atdata.index._schema import (
+ SchemaNamespace,
+ SchemaFieldType,
+ SchemaField,
+ LocalSchemaRecord,
+ _ATDATA_URI_PREFIX,
+ _LEGACY_URI_PREFIX,
+ _kind_str_for_sample_type,
+ _schema_ref_from_type,
+ _make_schema_ref,
+ _parse_schema_ref,
+ _increment_patch,
+ _python_type_to_field_type,
+ _build_schema_record,
+ )
+ from atdata.index._index import Index
+
+ __all__ = [
+ # Public API
+ "Index",
+ "LocalDatasetEntry",
+ "BasicIndexEntry",
+ "SchemaNamespace",
+ "SchemaFieldType",
+ "SchemaField",
+ "LocalSchemaRecord",
+ "REDIS_KEY_DATASET_ENTRY",
+ "REDIS_KEY_SCHEMA",
+ # Internal helpers (re-exported for backward compatibility)
+ "_ATDATA_URI_PREFIX",
+ "_LEGACY_URI_PREFIX",
+ "_kind_str_for_sample_type",
+ "_schema_ref_from_type",
+ "_make_schema_ref",
+ "_parse_schema_ref",
+ "_increment_patch",
+ "_python_type_to_field_type",
+ "_build_schema_record",
+ ]
atdata/{local → index}/_entry.py RENAMED
@@ -1,12 +1,16 @@
  """Dataset entry model and Redis key constants."""

+ from __future__ import annotations
+
  from atdata._cid import generate_cid

  from dataclasses import dataclass, field
- from typing import Any, cast
+ from typing import Any, TYPE_CHECKING, cast

  import msgpack
- from redis import Redis
+
+ if TYPE_CHECKING:
+ from redis import Redis


  # Redis key prefixes for index entries and schemas
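The rewritten imports follow the standard typing-only import pattern: ``redis`` is needed only for annotations, and ``from __future__ import annotations`` keeps those annotations unevaluated at runtime. The idiom in isolation (the function below is illustrative, not part of the module):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from redis import Redis   # never imported at runtime

    def is_alive(client: Redis) -> bool:
        # The annotation stays a string, so redis is only needed once a client is passed in.
        return bool(client.ping())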