atdata-0.2.2b1-py3-none-any.whl → atdata-0.3.0b1-py3-none-any.whl

Files changed (56)
  1. atdata/.gitignore +1 -0
  2. atdata/__init__.py +31 -1
  3. atdata/_cid.py +29 -35
  4. atdata/_exceptions.py +168 -0
  5. atdata/_helpers.py +33 -17
  6. atdata/_hf_api.py +109 -59
  7. atdata/_logging.py +70 -0
  8. atdata/_protocols.py +74 -132
  9. atdata/_schema_codec.py +38 -41
  10. atdata/_sources.py +57 -64
  11. atdata/_stub_manager.py +31 -26
  12. atdata/_type_utils.py +47 -7
  13. atdata/atmosphere/__init__.py +31 -24
  14. atdata/atmosphere/_types.py +11 -11
  15. atdata/atmosphere/client.py +11 -8
  16. atdata/atmosphere/lens.py +27 -30
  17. atdata/atmosphere/records.py +34 -39
  18. atdata/atmosphere/schema.py +35 -31
  19. atdata/atmosphere/store.py +16 -20
  20. atdata/cli/__init__.py +163 -168
  21. atdata/cli/diagnose.py +12 -8
  22. atdata/cli/inspect.py +69 -0
  23. atdata/cli/local.py +5 -2
  24. atdata/cli/preview.py +63 -0
  25. atdata/cli/schema.py +109 -0
  26. atdata/dataset.py +678 -533
  27. atdata/lens.py +85 -83
  28. atdata/local/__init__.py +71 -0
  29. atdata/local/_entry.py +157 -0
  30. atdata/local/_index.py +940 -0
  31. atdata/local/_repo_legacy.py +218 -0
  32. atdata/local/_s3.py +349 -0
  33. atdata/local/_schema.py +380 -0
  34. atdata/manifest/__init__.py +28 -0
  35. atdata/manifest/_aggregates.py +156 -0
  36. atdata/manifest/_builder.py +163 -0
  37. atdata/manifest/_fields.py +154 -0
  38. atdata/manifest/_manifest.py +146 -0
  39. atdata/manifest/_query.py +150 -0
  40. atdata/manifest/_writer.py +74 -0
  41. atdata/promote.py +20 -24
  42. atdata/providers/__init__.py +25 -0
  43. atdata/providers/_base.py +140 -0
  44. atdata/providers/_factory.py +69 -0
  45. atdata/providers/_postgres.py +214 -0
  46. atdata/providers/_redis.py +171 -0
  47. atdata/providers/_sqlite.py +191 -0
  48. atdata/repository.py +323 -0
  49. atdata/testing.py +337 -0
  50. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/METADATA +5 -1
  51. atdata-0.3.0b1.dist-info/RECORD +54 -0
  52. atdata/local.py +0 -1707
  53. atdata-0.2.2b1.dist-info/RECORD +0 -28
  54. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/WHEEL +0 -0
  55. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/entry_points.txt +0 -0
  56. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/licenses/LICENSE +0 -0
atdata/local/_schema.py
@@ -0,0 +1,380 @@
+ """Schema models and helper functions for local storage."""
+
+ from atdata._type_utils import (
+     PRIMITIVE_TYPE_MAP,
+     unwrap_optional,
+     is_ndarray_type,
+     extract_ndarray_dtype,
+     parse_semver,
+ )
+ from atdata._protocols import Packable
+
+ from dataclasses import dataclass, fields, is_dataclass
+ from datetime import datetime, timezone
+ from typing import (
+     Any,
+     Type,
+     TypeVar,
+     Iterator,
+     Optional,
+     Literal,
+     get_type_hints,
+     get_origin,
+     get_args,
+ )
+
+ T = TypeVar("T", bound=Packable)
+
+ # URI scheme prefixes
+ _ATDATA_URI_PREFIX = "atdata://local/sampleSchema/"
+ _LEGACY_URI_PREFIX = "local://schemas/"
+
+
+ class SchemaNamespace:
+     """Namespace for accessing loaded schema types as attributes.
+
+     After ``index.load_schema(uri)``, the type is available as an attribute.
+     Supports attribute access, iteration, ``len()``, and ``in`` checks.
+
+     Examples:
+         >>> index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
+         >>> MyType = index.types.MySample
+         >>> sample = MyType(field1="hello", field2=42)
+
+     Note:
+         For full IDE autocomplete, enable ``auto_stubs=True`` and add
+         ``index.stub_dir`` to your IDE's extraPaths.
+     """
+
+     def __init__(self) -> None:
+         self._types: dict[str, Type[Packable]] = {}
+
+     def _register(self, name: str, cls: Type[Packable]) -> None:
+         """Register a schema type in the namespace."""
+         self._types[name] = cls
+
+     def __getattr__(self, name: str) -> Any:
+         # Returns Any to avoid IDE complaints about unknown attributes.
+         # For full IDE support, import from the generated module instead.
+         if name.startswith("_"):
+             raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")
+         if name not in self._types:
+             raise AttributeError(
+                 f"Schema '{name}' not loaded. "
+                 f"Call index.load_schema() first to load the schema."
+             )
+         return self._types[name]
+
+     def __dir__(self) -> list[str]:
+         return list(self._types.keys()) + ["_types", "_register", "get"]
+
+     def __iter__(self) -> Iterator[str]:
+         return iter(self._types)
+
+     def __len__(self) -> int:
+         return len(self._types)
+
+     def __contains__(self, name: str) -> bool:
+         return name in self._types
+
+     def __repr__(self) -> str:
+         if not self._types:
+             return "SchemaNamespace(empty)"
+         names = ", ".join(sorted(self._types.keys()))
+         return f"SchemaNamespace({names})"
+
+     def get(self, name: str, default: T | None = None) -> Type[Packable] | T | None:
+         """Get a type by name, returning default if not found.
+
+         Args:
+             name: The schema class name to look up.
+             default: Value to return if not found (default: None).
+
+         Returns:
+             The schema class, or default if not loaded.
+         """
+         return self._types.get(name, default)
+
+
+ ##
+ # Schema types
+
+
+ @dataclass
+ class SchemaFieldType:
+     """Schema field type definition for local storage.
+
+     Represents a type in the schema type system, supporting primitives,
+     ndarrays, arrays, and references to other schemas.
+     """
+
+     kind: Literal["primitive", "ndarray", "ref", "array"]
+     """The category of type."""
+
+     primitive: Optional[str] = None
+     """For kind='primitive': one of 'str', 'int', 'float', 'bool', 'bytes'."""
+
+     dtype: Optional[str] = None
+     """For kind='ndarray': numpy dtype string (e.g., 'float32')."""
+
+     ref: Optional[str] = None
+     """For kind='ref': URI of referenced schema."""
+
+     items: Optional["SchemaFieldType"] = None
+     """For kind='array': type of array elements."""
+
+     @classmethod
+     def from_dict(cls, data: dict) -> "SchemaFieldType":
+         """Create from a dictionary (e.g., from Redis storage)."""
+         type_str = data.get("$type", "")
+         if "#" in type_str:
+             kind = type_str.split("#")[-1]
+         else:
+             kind = data.get("kind", "primitive")
+
+         items = None
+         if "items" in data and data["items"]:
+             items = cls.from_dict(data["items"])
+
+         return cls(
+             kind=kind,  # type: ignore[arg-type]
+             primitive=data.get("primitive"),
+             dtype=data.get("dtype"),
+             ref=data.get("ref"),
+             items=items,
+         )
+
+     def to_dict(self) -> dict:
+         """Convert to dictionary for storage."""
+         result: dict[str, Any] = {"$type": f"local#{self.kind}"}
+         if self.kind == "primitive":
+             result["primitive"] = self.primitive
+         elif self.kind == "ndarray":
+             result["dtype"] = self.dtype
+         elif self.kind == "ref":
+             result["ref"] = self.ref
+         elif self.kind == "array" and self.items:
+             result["items"] = self.items.to_dict()
+         return result
+
+
+ @dataclass
+ class SchemaField:
+     """Schema field definition for local storage."""
+
+     name: str
+     """Field name."""
+
+     field_type: SchemaFieldType
+     """Type of this field."""
+
+     optional: bool = False
+     """Whether this field can be None."""
+
+     @classmethod
+     def from_dict(cls, data: dict) -> "SchemaField":
+         """Create from a dictionary."""
+         return cls(
+             name=data["name"],
+             field_type=SchemaFieldType.from_dict(data["fieldType"]),
+             optional=data.get("optional", False),
+         )
+
+     def to_dict(self) -> dict:
+         """Convert to dictionary for storage."""
+         return {
+             "name": self.name,
+             "fieldType": self.field_type.to_dict(),
+             "optional": self.optional,
+         }
+
+
+ @dataclass
+ class LocalSchemaRecord:
+     """Schema record for local storage.
+
+     Represents a PackableSample schema stored in the local index.
+     Aligns with the atmosphere SchemaRecord structure for seamless promotion.
+     """
+
+     name: str
+     """Schema name (typically the class name)."""
+
+     version: str
+     """Semantic version string (e.g., '1.0.0')."""
+
+     fields: list[SchemaField]
+     """List of field definitions."""
+
+     ref: str
+     """Schema reference URI (atdata://local/sampleSchema/{name}@{version})."""
+
+     description: Optional[str] = None
+     """Human-readable description."""
+
+     created_at: Optional[datetime] = None
+     """When this schema was published."""
+
+     @classmethod
+     def from_dict(cls, data: dict) -> "LocalSchemaRecord":
+         """Create from a dictionary (e.g., from Redis storage)."""
+         created_at = None
+         if "createdAt" in data:
+             try:
+                 created_at = datetime.fromisoformat(data["createdAt"])
+             except (ValueError, TypeError):
+                 created_at = None  # Invalid datetime format, leave as None
+
+         return cls(
+             name=data["name"],
+             version=data["version"],
+             fields=[SchemaField.from_dict(f) for f in data.get("fields", [])],
+             ref=data.get("$ref", ""),
+             description=data.get("description"),
+             created_at=created_at,
+         )
+
+     def to_dict(self) -> dict:
+         """Convert to dictionary for storage."""
+         result: dict[str, Any] = {
+             "name": self.name,
+             "version": self.version,
+             "fields": [f.to_dict() for f in self.fields],
+             "$ref": self.ref,
+         }
+         if self.description:
+             result["description"] = self.description
+         if self.created_at:
+             result["createdAt"] = self.created_at.isoformat()
+         return result
+
+
+ ##
+ # Schema helpers
+
+
+ def _kind_str_for_sample_type(st: Type[Packable]) -> str:
+     """Return fully-qualified 'module.name' string for a sample type."""
+     return f"{st.__module__}.{st.__name__}"
+
+
+ def _schema_ref_from_type(sample_type: Type[Packable], version: str) -> str:
+     """Generate 'atdata://local/sampleSchema/{name}@{version}' reference."""
+     return _make_schema_ref(sample_type.__name__, version)
+
+
+ def _make_schema_ref(name: str, version: str) -> str:
+     """Generate schema reference URI from name and version."""
+     return f"{_ATDATA_URI_PREFIX}{name}@{version}"
+
+
+ def _parse_schema_ref(ref: str) -> tuple[str, str]:
+     """Parse schema reference into (name, version).
+
+     Supports both the new format 'atdata://local/sampleSchema/{name}@{version}'
+     and the legacy format 'local://schemas/{module.Class}@{version}'.
+     """
+     if ref.startswith(_ATDATA_URI_PREFIX):
+         path = ref[len(_ATDATA_URI_PREFIX) :]
+     elif ref.startswith(_LEGACY_URI_PREFIX):
+         path = ref[len(_LEGACY_URI_PREFIX) :]
+     else:
+         raise ValueError(f"Invalid schema reference: {ref}")
+
+     if "@" not in path:
+         raise ValueError(f"Schema reference must include version (@version): {ref}")
+
+     name, version = path.rsplit("@", 1)
+     # For legacy format, extract just the class name from module.Class
+     if "." in name:
+         name = name.rsplit(".", 1)[1]
+     return name, version
+
+
+ def _increment_patch(version: str) -> str:
+     """Increment patch version: 1.0.0 -> 1.0.1"""
+     major, minor, patch = parse_semver(version)
+     return f"{major}.{minor}.{patch + 1}"
+
+
+ def _python_type_to_field_type(python_type: Any) -> dict:
+     """Convert Python type annotation to schema field type dict."""
+     if python_type in PRIMITIVE_TYPE_MAP:
+         return {
+             "$type": "local#primitive",
+             "primitive": PRIMITIVE_TYPE_MAP[python_type],
+         }
+
+     if is_ndarray_type(python_type):
+         return {"$type": "local#ndarray", "dtype": extract_ndarray_dtype(python_type)}
+
+     origin = get_origin(python_type)
+     if origin is list:
+         args = get_args(python_type)
+         items = (
+             _python_type_to_field_type(args[0])
+             if args
+             else {"$type": "local#primitive", "primitive": "str"}
+         )
+         return {"$type": "local#array", "items": items}
+
+     if is_dataclass(python_type):
+         raise TypeError(
+             f"Nested dataclass types not yet supported: {python_type.__name__}. "
+             "Publish nested types separately and use references."
+         )
+
+     raise TypeError(f"Unsupported type for schema field: {python_type}")
+
+
+ def _build_schema_record(
+     sample_type: Type[Packable],
+     *,
+     version: str,
+     description: str | None = None,
+ ) -> dict:
+     """Build a schema record dict from a PackableSample type.
+
+     Args:
+         sample_type: The PackableSample subclass to introspect.
+         version: Semantic version string.
+         description: Optional human-readable description. If None, uses the
+             class docstring.
+
+     Returns:
+         Schema record dict suitable for Redis storage.
+
+     Raises:
+         ValueError: If sample_type is not a dataclass.
+         TypeError: If a field type is not supported.
+     """
+     if not is_dataclass(sample_type):
+         raise ValueError(f"{sample_type.__name__} must be a dataclass (use @packable)")
+
+     # Use docstring as fallback for description
+     if description is None:
+         description = sample_type.__doc__
+
+     field_defs = []
+     type_hints = get_type_hints(sample_type)
+
+     for f in fields(sample_type):
+         field_type = type_hints.get(f.name, f.type)
+         field_type, is_optional = unwrap_optional(field_type)
+         field_type_dict = _python_type_to_field_type(field_type)
+
+         field_defs.append(
+             {
+                 "name": f.name,
+                 "fieldType": field_type_dict,
+                 "optional": is_optional,
+             }
+         )
+
+     return {
+         "name": sample_type.__name__,
+         "version": version,
+         "fields": field_defs,
+         "description": description,
+         "createdAt": datetime.now(timezone.utc).isoformat(),
+     }
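
The helpers above round-trip schema references and field-type dicts. A quick sketch of how they compose, assuming direct imports from the new atdata/local/_schema.py module (these are underscore-private helpers, so the import path is illustrative rather than a supported API):

    from atdata.local._schema import (
        SchemaFieldType,
        _increment_patch,
        _make_schema_ref,
        _parse_schema_ref,
    )

    # Build and parse a schema reference URI; legacy refs parse too,
    # with the class name extracted from "module.Class".
    ref = _make_schema_ref("MySample", "1.0.0")
    assert ref == "atdata://local/sampleSchema/MySample@1.0.0"
    assert _parse_schema_ref(ref) == ("MySample", "1.0.0")
    assert _parse_schema_ref("local://schemas/pkg.MySample@1.0.0") == ("MySample", "1.0.0")
    assert _increment_patch("1.0.0") == "1.0.1"

    # Field types round-trip through their storage dict form.
    ft = SchemaFieldType.from_dict(
        {"$type": "local#array", "items": {"$type": "local#primitive", "primitive": "int"}}
    )
    assert ft.kind == "array" and ft.items is not None and ft.items.primitive == "int"
    assert ft.to_dict() == {
        "$type": "local#array",
        "items": {"$type": "local#primitive", "primitive": "int"},
    }
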
atdata/manifest/__init__.py
@@ -0,0 +1,28 @@
+ """Per-shard manifest and query system.
+
+ Provides manifest generation during shard writes and efficient
+ query execution over large datasets without full scans.
+
+ Components:
+
+ - ``ManifestField``: Annotation marker for manifest-included fields
+ - ``ManifestBuilder``: Accumulates sample metadata during writes
+ - ``ShardManifest``: Loaded manifest representation
+ - ``ManifestWriter``: Serializes manifests to JSON + parquet
+ - ``QueryExecutor``: Two-phase query over manifest metadata
+ - ``SampleLocation``: Address of a sample within a shard
+ """
+
+ from ._fields import ManifestField as ManifestField
+ from ._fields import AggregateKind as AggregateKind
+ from ._fields import resolve_manifest_fields as resolve_manifest_fields
+ from ._aggregates import CategoricalAggregate as CategoricalAggregate
+ from ._aggregates import NumericAggregate as NumericAggregate
+ from ._aggregates import SetAggregate as SetAggregate
+ from ._aggregates import create_aggregate as create_aggregate
+ from ._builder import ManifestBuilder as ManifestBuilder
+ from ._manifest import ShardManifest as ShardManifest
+ from ._manifest import MANIFEST_FORMAT_VERSION as MANIFEST_FORMAT_VERSION
+ from ._writer import ManifestWriter as ManifestWriter
+ from ._query import QueryExecutor as QueryExecutor
+ from ._query import SampleLocation as SampleLocation
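
Since the module is purely re-exports, the package root is the intended import surface. A minimal sketch, using only the names re-exported above (constructor signatures live in the sibling modules and are not shown in this diff):

    from atdata.manifest import (
        ManifestField,
        AggregateKind,
        resolve_manifest_fields,
        CategoricalAggregate,
        NumericAggregate,
        SetAggregate,
        create_aggregate,
        ManifestBuilder,
        ShardManifest,
        MANIFEST_FORMAT_VERSION,
        ManifestWriter,
        QueryExecutor,
        SampleLocation,
    )
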
atdata/manifest/_aggregates.py
@@ -0,0 +1,156 @@
+ """Statistical aggregate collectors for manifest fields.
+
+ Each aggregate type tracks running statistics during shard writing and
+ produces a summary dict for inclusion in the manifest JSON header.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Any
+
+
+ @dataclass
+ class CategoricalAggregate:
+     """Aggregate for categorical (string/enum) fields.
+
+     Tracks value counts and cardinality across all samples in a shard.
+
+     Examples:
+         >>> agg = CategoricalAggregate()
+         >>> agg.add("dog")
+         >>> agg.add("cat")
+         >>> agg.add("dog")
+         >>> agg.to_dict()
+         {'type': 'categorical', 'cardinality': 2, 'value_counts': {'dog': 2, 'cat': 1}}
+     """
+
+     value_counts: dict[str, int] = field(default_factory=dict)
+
+     @property
+     def cardinality(self) -> int:
+         """Number of distinct values observed."""
+         return len(self.value_counts)
+
+     def add(self, value: Any) -> None:
+         """Record a value observation."""
+         key = str(value)
+         self.value_counts[key] = self.value_counts.get(key, 0) + 1
+
+     def to_dict(self) -> dict[str, Any]:
+         """Serialize to a JSON-compatible dict."""
+         return {
+             "type": "categorical",
+             "cardinality": self.cardinality,
+             "value_counts": dict(self.value_counts),
+         }
+
+
+ @dataclass
+ class NumericAggregate:
+     """Aggregate for numeric (int/float) fields.
+
+     Tracks min, max, sum, and count for computing summary statistics.
+
+     Examples:
+         >>> agg = NumericAggregate()
+         >>> agg.add(1.0)
+         >>> agg.add(3.0)
+         >>> agg.add(2.0)
+         >>> agg.to_dict()
+         {'type': 'numeric', 'min': 1.0, 'max': 3.0, 'mean': 2.0, 'count': 3}
+     """
+
+     _min: float = field(default=float("inf"))
+     _max: float = field(default=float("-inf"))
+     _sum: float = 0.0
+     count: int = 0
+
+     @property
+     def min(self) -> float:
+         """Minimum observed value."""
+         return self._min
+
+     @property
+     def max(self) -> float:
+         """Maximum observed value."""
+         return self._max
+
+     @property
+     def mean(self) -> float:
+         """Running mean of observed values."""
+         if self.count == 0:
+             return 0.0
+         return self._sum / self.count
+
+     def add(self, value: float | int) -> None:
+         """Record a numeric observation."""
+         v = float(value)
+         if v < self._min:
+             self._min = v
+         if v > self._max:
+             self._max = v
+         self._sum += v
+         self.count += 1
+
+     def to_dict(self) -> dict[str, Any]:
+         """Serialize to a JSON-compatible dict."""
+         return {
+             "type": "numeric",
+             "min": self._min,
+             "max": self._max,
+             "mean": self.mean,
+             "count": self.count,
+         }
+
+
+ @dataclass
+ class SetAggregate:
+     """Aggregate for set/tag (list) fields.
+
+     Tracks the union of all observed values across samples.
+
+     Examples:
+         >>> agg = SetAggregate()
+         >>> agg.add(["outdoor", "day"])
+         >>> agg.add(["indoor"])
+         >>> agg.to_dict()
+         {'type': 'set', 'all_values': ['day', 'indoor', 'outdoor']}
+     """
+
+     all_values: set[str] = field(default_factory=set)
+
+     def add(self, values: list | set | tuple) -> None:
+         """Record a collection of values."""
+         for v in values:
+             self.all_values.add(str(v))
+
+     def to_dict(self) -> dict[str, Any]:
+         """Serialize to a JSON-compatible dict."""
+         return {
+             "type": "set",
+             "all_values": sorted(self.all_values),
+         }
+
+
+ def create_aggregate(
+     kind: str,
+ ) -> CategoricalAggregate | NumericAggregate | SetAggregate:
+     """Create an aggregate collector for the given kind.
+
+     Args:
+         kind: One of ``"categorical"``, ``"numeric"``, ``"set"``.
+
+     Returns:
+         A new aggregate collector instance.
+
+     Raises:
+         ValueError: If kind is not recognized.
+     """
+     if kind == "categorical":
+         return CategoricalAggregate()
+     if kind == "numeric":
+         return NumericAggregate()
+     if kind == "set":
+         return SetAggregate()
+     raise ValueError(f"Unknown aggregate kind: {kind!r}")
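
As the doctests above show, the three collectors share an add/to_dict shape, and ``create_aggregate`` dispatches on the kind string. A small end-to-end sketch, mirroring those doctests via the package-root re-exports:

    from atdata.manifest import SetAggregate, create_aggregate

    # Numeric statistics accumulate as values stream past during a shard write.
    num = create_aggregate("numeric")
    for v in (1.0, 3.0, 2.0):
        num.add(v)
    assert num.to_dict() == {
        "type": "numeric", "min": 1.0, "max": 3.0, "mean": 2.0, "count": 3,
    }

    # Set aggregates keep the sorted union of every tag seen.
    tags = SetAggregate()
    tags.add(["outdoor", "day"])
    tags.add(["indoor"])
    assert tags.to_dict() == {"type": "set", "all_values": ["day", "indoor", "outdoor"]}

    # Unknown kinds fail fast.
    try:
        create_aggregate("bogus")
    except ValueError as e:
        print(e)  # Unknown aggregate kind: 'bogus'
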