eiax 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eiax/__init__.py ADDED
@@ -0,0 +1,34 @@
1
+ """Public exports."""
2
+
3
+ from eiax.__version__ import __version__
4
+ from eiax.cache import CacheConfig
5
+ from eiax.catalog import facet_values, help_route, search
6
+ from eiax.client import EIAClient
7
+ from eiax.errors import (
8
+ AuthenticationError,
9
+ EIAError,
10
+ EmptyResultError,
11
+ RateLimitError,
12
+ UnknownSeriesError,
13
+ )
14
+ from eiax.fetch import fetch_async, fetch_sync
15
+ from eiax.series import to_wide
16
+
17
# ``fetch`` is the short public alias for ``fetch_sync``; note that
# ``fetch_sync`` itself is not listed in ``__all__``.
fetch = fetch_sync

# Names exported by ``from eiax import *`` (the package's public API).
__all__ = [
    "__version__",
    "AuthenticationError",
    "CacheConfig",
    "EIAClient",
    "EIAError",
    "EmptyResultError",
    "RateLimitError",
    "UnknownSeriesError",
    "facet_values",
    "fetch",
    "fetch_async",
    "help_route",
    "search",
    "to_wide",
]
eiax/__version__.py ADDED
@@ -0,0 +1 @@
1
# Package version string; re-exported by ``eiax/__init__.py`` as ``eiax.__version__``.
__version__ = "0.1.0"
eiax/_sync.py ADDED
@@ -0,0 +1,18 @@
1
+ """Run async coroutines from sync call sites (including Jupyter notebooks)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ from collections.abc import Coroutine
7
+ from concurrent.futures import ThreadPoolExecutor
8
+
9
+
10
+ def run_sync[T](coro: Coroutine[object, object, T]) -> T:
11
+ """ponytail: thread + fresh loop when a loop is already running (Jupyter)."""
12
+ try:
13
+ asyncio.get_running_loop()
14
+ except RuntimeError:
15
+ return asyncio.run(coro)
16
+
17
+ with ThreadPoolExecutor(max_workers=1) as pool:
18
+ return pool.submit(asyncio.run, coro).result()
eiax/cache.py ADDED
@@ -0,0 +1,336 @@
1
+ """Parquet cache + sqlite manifest, gap detection, TTL."""
2
+
3
+ from __future__ import annotations
4
+
5
import hashlib
import json
import os
import sqlite3
from contextlib import closing
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import NamedTuple

import polars as pl

from eiax.settings import get_settings
17
+
18
# DDL for the sqlite manifest: one row per cached partition, keyed by
# (route, frequency, facets_key), recording the covered period range, the
# parquet row count and when it was last written. ``_connect`` runs this on
# every connection; ``IF NOT EXISTS`` makes repeat runs a no-op.
_MANIFEST_SCHEMA = """
CREATE TABLE IF NOT EXISTS partitions (
route TEXT NOT NULL,
frequency TEXT NOT NULL,
facets_key TEXT NOT NULL,
covered_start TEXT NOT NULL,
covered_end TEXT NOT NULL,
row_count INTEGER NOT NULL,
written_at TEXT NOT NULL,
PRIMARY KEY (route, frequency, facets_key)
);
"""
30
+
31
+
32
@dataclass
class CacheConfig:
    """Settings for the on-disk parquet cache.

    Construct directly, or via :meth:`from_settings` to read the package
    settings.
    """

    # Cache on/off switch. Not read by ``CacheStore`` itself; presumably
    # consulted by callers before using the cache — confirm.
    enabled: bool = True
    # Explicit cache root; ``None`` falls back to XDG / ``~/.cache`` (see
    # ``resolve_dir``).
    cache_dir: Path | None = None
    # Manifest entries older than this many hours trigger a re-fetch of the
    # most recent window during gap detection.
    recent_ttl_hours: int = 48

    @classmethod
    def from_settings(cls) -> CacheConfig:
        """Build a config from ``get_settings()``.

        Reads ``cache_enabled``, ``cache_dir`` and ``cache_ttl_hours``.
        """
        cfg = get_settings()
        return cls(
            enabled=cfg.cache_enabled,
            cache_dir=cfg.cache_dir,
            recent_ttl_hours=cfg.cache_ttl_hours,
        )

    def resolve_dir(self) -> Path:
        """Return the directory that holds cached partitions and the manifest.

        Precedence:
        1. ``cache_dir``, with a leading ``~`` expanded.
        2. ``$XDG_CACHE_HOME/eiax``, if that variable is an absolute path.
        3. ``~/.cache/eiax``.
        """
        if self.cache_dir is not None:
            # Values from settings/env are often written as "~/..."; without
            # expansion they would create a literal "~" directory in the CWD.
            return Path(self.cache_dir).expanduser()
        xdg_cache = os.environ.get("XDG_CACHE_HOME")
        # XDG Base Directory spec: relative paths in these variables are
        # invalid and must be ignored.
        if xdg_cache and Path(xdg_cache).is_absolute():
            return Path(xdg_cache) / "eiax"
        return Path.home() / ".cache" / "eiax"
53
+
54
+
55
class DateRange(NamedTuple):
    """A period range stored as its two bound strings.

    The strings are kept as given (e.g. "2024-01-01" or "2024-03-13T00");
    code in this module orders and compares them after parsing with
    ``_parse_bound``.
    """

    start: str  # lower bound string
    end: str  # upper bound string
58
+
59
+
60
+ def facets_key(facets: dict[str, str | list[str]]) -> str:
61
+ normalized = json.dumps(facets, sort_keys=True, default=list)
62
+ return hashlib.sha256(normalized.encode()).hexdigest()[:16]
63
+
64
+
65
def _connect(db_path: Path) -> sqlite3.Connection:
    """Open the manifest database and ensure its schema exists.

    Switches the database to WAL journaling and runs ``_MANIFEST_SCHEMA``;
    its ``IF NOT EXISTS`` makes repeated calls harmless.

    The caller owns the returned connection and must close it. Using
    ``with conn:`` only commits or rolls back; it does not close.

    Any error raised while opening or initializing the database propagates,
    and the half-initialized connection is closed first.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute("PRAGMA journal_mode=WAL")
        conn.executescript(_MANIFEST_SCHEMA)
    except BaseException:
        # Don't leak the sqlite handle (and its file descriptor) on failure.
        conn.close()
        raise
    return conn
70
+
71
+
72
+ def _partition_dir(cache_dir: Path, route: str, frequency: str, fkey: str) -> Path:
73
+ return cache_dir.joinpath(*route.strip("/").split("/"), frequency or "_", fkey)
74
+
75
+
76
+ def _parse_bound(value: str) -> datetime:
77
+ for fmt in (
78
+ "%Y-%m-%dT%H:%M",
79
+ "%Y-%m-%dT%H",
80
+ "%Y-%m-%d",
81
+ "%Y-%m",
82
+ "%Y",
83
+ ):
84
+ try:
85
+ dt = datetime.strptime(value, fmt)
86
+ return dt.replace(tzinfo=UTC)
87
+ except ValueError:
88
+ continue
89
+ msg = f"unsupported period bound: {value!r}"
90
+ raise ValueError(msg)
91
+
92
+
93
def _period_bounds(df: pl.DataFrame) -> tuple[str, str]:
    """Return the earliest and latest ``period`` of *df* as cache-bound strings.

    Uses ``min``/``max`` rather than sorting the whole frame. That is O(n)
    instead of O(n log n), avoids copying every column, and ignores null
    periods. The previous sort placed nulls first, which turned ``str(None)``
    into a bound that ``_parse_bound`` rejects.

    Raises:
        ValueError: if *df* has no non-null ``period`` values.
    """
    periods = df["period"]
    start, end = periods.min(), periods.max()
    if start is None or end is None:
        msg = "cannot compute period bounds: frame has no non-null periods"
        raise ValueError(msg)
    return _period_to_cache_str(start), _period_to_cache_str(end)
98
+
99
+
100
+ def _period_to_cache_str(value: object) -> str:
101
+ if isinstance(value, datetime):
102
+ dt = value.astimezone(UTC)
103
+ if dt.hour or dt.minute:
104
+ if dt.minute:
105
+ return dt.strftime("%Y-%m-%dT%H:%M")
106
+ return dt.strftime("%Y-%m-%dT%H")
107
+ if dt.day == 1 and dt.month == 1 and dt.hour == 0:
108
+ return dt.strftime("%Y")
109
+ if dt.day == 1 and dt.hour == 0:
110
+ return dt.strftime("%Y-%m")
111
+ return dt.strftime("%Y-%m-%d")
112
+ return str(value)
113
+
114
+
115
def _merge_ranges(ranges: list[DateRange]) -> list[DateRange]:
    """Collapse overlapping ranges into a list ordered by start.

    Ranges that merely touch (next start == previous end) are merged too.
    Comparisons use ``_parse_bound`` values, while the original bound
    strings are kept in the result. On ties the earlier range's end string
    wins.
    """
    if not ranges:
        return []
    by_start = sorted(ranges, key=lambda rng: _parse_bound(rng.start))
    result: list[DateRange] = [by_start[0]]
    for nxt in by_start[1:]:
        last = result[-1]
        nxt_start = _parse_bound(nxt.start)
        last_end = _parse_bound(last.end)
        if nxt_start > last_end:
            # Disjoint: start a new merged range.
            result.append(nxt)
            continue
        # Overlapping/touching: extend only if ``nxt`` reaches further.
        new_end = nxt.end if _parse_bound(nxt.end) > last_end else last.end
        result[-1] = DateRange(last.start, new_end)
    return result
132
+
133
+
134
def find_gaps(
    row: tuple[str, str, str] | None,
    *,
    start: str,
    end: str,
    recent_ttl_hours: int,
    now: datetime | None = None,
) -> list[DateRange]:
    """Return the sub-ranges of ``[start, end]`` that still need a network fetch.

    Args:
        row: ``(covered_start, covered_end, written_at)`` from the manifest, or
            ``None`` when nothing is cached (the whole request is then a gap).
        start: Requested lower bound, in a format ``_parse_bound`` accepts.
        end: Requested upper bound, in a format ``_parse_bound`` accepts.
        recent_ttl_hours: Once the manifest entry is older than this, the
            window from ``now - ttl`` (or *start*, if later) to *end* is
            re-fetched as well.
        now: Reference time; defaults to the current UTC time.

    Edge gaps use the cached bound (``covered_start`` / ``covered_end``) as
    their inner end, so fetched data always joins up with the cached range.
    Overlapping gaps are merged before returning.

    NOTE(review): the TTL cutoff is always rendered at hourly precision
    ("%Y-%m-%dT%H") regardless of the route's frequency. Confirm that
    callers accept that for monthly/annual data.
    """
    if row is None:
        return [DateRange(start, end)]

    covered_start, covered_end, written_raw = row
    req_lo, req_hi = _parse_bound(start), _parse_bound(end)
    cov_lo, cov_hi = _parse_bound(covered_start), _parse_bound(covered_end)

    pending: list[DateRange] = []
    # Requested data earlier than anything cached.
    if req_lo < cov_lo:
        pending.append(DateRange(start, covered_start))
    # Requested data later than anything cached.
    if req_hi > cov_hi:
        pending.append(DateRange(covered_end, end))

    reference = datetime.now(UTC) if now is None else now
    written_at = datetime.fromisoformat(written_raw)
    if written_at.tzinfo is None:
        written_at = written_at.replace(tzinfo=UTC)

    ttl = timedelta(hours=recent_ttl_hours)
    if reference - written_at > ttl:
        cutoff = reference - ttl
        if cutoff <= req_hi:
            # Re-fetch the recent window, clamped to the requested start.
            if req_lo >= cutoff:
                recent_from = start
            else:
                recent_from = cutoff.strftime("%Y-%m-%dT%H")
            pending.append(DateRange(recent_from, end))

    return _merge_ranges(pending)
171
+
172
+
173
def _key_columns(df: pl.DataFrame) -> list[str]:
    """Return the columns that identify a row: ``period`` plus facet codes.

    Only string (Utf8) columns other than ``period`` qualify. Numeric
    measures (Float64 after parsing) are therefore never part of the key,
    whatever their name. String columns ending in ``-name`` are skipped
    because they only describe the code column they accompany.

    Keying on codes keeps "newer row replaces older row" correct even for
    routes whose measure column is not literally called ``value``.
    """
    facet_codes = [
        column
        for column, dtype in df.schema.items()
        if column != "period"
        and dtype == pl.Utf8
        and not column.endswith("-name")
    ]
    return ["period", *facet_codes]
188
+
189
+
190
class CacheStore:
    """On-disk cache: one parquet file per partition plus a sqlite manifest.

    A partition is identified by ``(route, frequency, facets_key(facets))``.
    Its rows live in ``<partition dir>/data.parquet``. The manifest table
    ``partitions`` records the covered period range, the row count and the
    write time that gap detection relies on.
    """

    def __init__(self, config: CacheConfig) -> None:
        """Resolve the cache directory for *config*, creating it if missing."""
        self.config = config
        self.cache_dir = config.resolve_dir()
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._manifest = self.cache_dir / "manifest.db"

    def _manifest_row(
        self, route: str, frequency: str, fkey: str
    ) -> tuple[str, str, str, int] | None:
        """Look up one partition's manifest entry.

        Returns:
            ``(covered_start, covered_end, written_at, row_count)``, or
            ``None`` if the partition has never been written.
        """
        # ``with conn`` only commits/rolls back the transaction; ``closing``
        # is what actually releases the sqlite handle.
        with closing(_connect(self._manifest)) as conn, conn:
            row = conn.execute(
                """
                SELECT covered_start, covered_end, written_at, row_count
                FROM partitions
                WHERE route = ? AND frequency = ? AND facets_key = ?
                """,
                (route, frequency or "", fkey),
            ).fetchone()
        if row is None:
            return None
        return row[0], row[1], row[2], row[3]

    def gaps(
        self,
        route: str,
        frequency: str | None,
        facets: dict[str, str | list[str]],
        start: str,
        end: str,
    ) -> list[DateRange]:
        """Return the sub-ranges of ``[start, end]`` not served by the cache.

        Without a manifest entry the whole request is one gap. Otherwise
        ``find_gaps`` decides, using ``config.recent_ttl_hours``.
        """
        row = self._manifest_row(route, frequency or "", facets_key(facets))
        if row is None:
            return find_gaps(None, start=start, end=end, recent_ttl_hours=0)
        covered_start, covered_end, written_at, _row_count = row
        return find_gaps(
            (covered_start, covered_end, written_at),
            start=start,
            end=end,
            recent_ttl_hours=self.config.recent_ttl_hours,
        )

    def partition_path(
        self, route: str, frequency: str | None, facets: dict[str, str | list[str]]
    ) -> Path:
        """Return the directory holding this partition's ``data.parquet``."""
        return _partition_dir(
            self.cache_dir, route, frequency or "_", facets_key(facets)
        )

    def read_slice(
        self,
        route: str,
        frequency: str | None,
        facets: dict[str, str | list[str]],
        start: str,
        end: str,
    ) -> pl.DataFrame | None:
        """Return cached rows with ``start <= period <= end``.

        Returns ``None`` if the partition file does not exist. The bounds are
        parsed with ``_parse_bound`` and compared as ``Datetime("us", "UTC")``
        literals. This presumably assumes the stored ``period`` column is a
        UTC datetime — confirm.
        """
        path = self.partition_path(route, frequency, facets) / "data.parquet"
        if not path.is_file():
            return None
        start_dt = _parse_bound(start)
        end_dt = _parse_bound(end)
        return (
            pl.scan_parquet(path)
            .filter(
                pl.col("period") >= pl.lit(start_dt, dtype=pl.Datetime("us", "UTC")),
                pl.col("period") <= pl.lit(end_dt, dtype=pl.Datetime("us", "UTC")),
            )
            .collect()
        )

    def merge_write(
        self,
        route: str,
        frequency: str | None,
        facets: dict[str, str | list[str]],
        frame: pl.DataFrame,
    ) -> None:
        """Merge *frame* into the partition's parquet file and update the manifest.

        On key collisions the rows in *frame* win over cached ones. The key is
        ``period`` plus facet-code columns; see ``_key_columns``.

        The parquet file is replaced via a temp file and ``os.replace``. A
        failed write removes the temp file. An empty *frame* is a no-op.
        The manifest's covered range becomes the min/max period of the
        merged data.
        """
        if frame.is_empty():
            return
        part_dir = self.partition_path(route, frequency, facets)
        part_dir.mkdir(parents=True, exist_ok=True)
        path = part_dir / "data.parquet"
        tmp = part_dir / "data.parquet.tmp"

        existing = pl.read_parquet(path) if path.is_file() else None
        keys = _key_columns(frame)
        if existing is not None and not existing.is_empty():
            # ``frame`` comes last so ``keep="last"`` prefers the new rows.
            combined = pl.concat([existing, frame], how="diagonal_relaxed")
        else:
            combined = frame
        merged = combined.unique(subset=keys, keep="last").sort("period")

        try:
            merged.write_parquet(tmp)
            os.replace(tmp, path)
        except BaseException:
            # Don't leave a half-written temp file behind.
            tmp.unlink(missing_ok=True)
            raise

        covered_start, covered_end = _period_bounds(merged)
        written_at = datetime.now(UTC).isoformat()
        # ``with conn`` commits on success; ``closing`` releases the handle.
        with closing(_connect(self._manifest)) as conn, conn:
            conn.execute(
                """
                INSERT INTO partitions (
                    route, frequency, facets_key,
                    covered_start, covered_end, row_count, written_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(route, frequency, facets_key) DO UPDATE SET
                    covered_start = excluded.covered_start,
                    covered_end = excluded.covered_end,
                    row_count = excluded.row_count,
                    written_at = excluded.written_at
                """,
                (
                    route,
                    frequency or "",
                    facets_key(facets),
                    covered_start,
                    covered_end,
                    merged.height,
                    written_at,
                ),
            )
317
+
318
+
319
if __name__ == "__main__":
    # Offline smoke test of gap detection (no disk, no network).
    ttl_hours = 48

    # Nothing cached yet: the whole request is a single gap.
    uncached = find_gaps(
        None, start="2024-01-01", end="2024-03-01", recent_ttl_hours=ttl_hours
    )
    assert uncached == [DateRange("2024-01-01", "2024-03-01")]

    # Freshly written partial coverage: both uncovered edges are gaps.
    fresh_row = ("2024-02-01", "2024-02-28", datetime.now(UTC).isoformat())
    edge_gaps = find_gaps(
        fresh_row, start="2024-01-01", end="2024-03-01", recent_ttl_hours=ttl_hours
    )
    for expected_gap in (
        DateRange("2024-01-01", "2024-02-01"),
        DateRange("2024-02-28", "2024-03-01"),
    ):
        assert expected_gap in edge_gaps

    # Entry written long ago and request extending past coverage: some
    # fetch must be scheduled.
    stale_row = ("2024-01-01", "2024-03-01", "2020-01-01T00:00:00+00:00")
    stale_gaps = find_gaps(
        stale_row,
        start="2024-01-01",
        end="2024-03-15",
        recent_ttl_hours=ttl_hours,
        now=datetime(2024, 3, 15, tzinfo=UTC),
    )
    assert stale_gaps
    print("cache self-check ok")