bagelquant-data 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bagelquant_data/__init__.py +32 -0
- bagelquant_data/cli/main.py +30 -0
- bagelquant_data/core/__init__.py +37 -0
- bagelquant_data/core/dataset.py +170 -0
- bagelquant_data/core/deduplication.py +40 -0
- bagelquant_data/core/exceptions.py +39 -0
- bagelquant_data/core/hashing.py +32 -0
- bagelquant_data/core/normalization.py +67 -0
- bagelquant_data/core/partitioning.py +76 -0
- bagelquant_data/core/registry.py +67 -0
- bagelquant_data/core/request.py +21 -0
- bagelquant_data/core/source.py +38 -0
- bagelquant_data/core/types.py +10 -0
- bagelquant_data/core/validation.py +32 -0
- bagelquant_data/finance/__init__.py +80 -0
- bagelquant_data/finance/align.py +5 -0
- bagelquant_data/finance/fields.py +30 -0
- bagelquant_data/finance/flows.py +24 -0
- bagelquant_data/finance/periods.py +11 -0
- bagelquant_data/finance/point_in_time.py +27 -0
- bagelquant_data/finance/ratios.py +31 -0
- bagelquant_data/finance/rolling.py +34 -0
- bagelquant_data/finance/shares.py +28 -0
- bagelquant_data/finance/stocks.py +27 -0
- bagelquant_data/management/__init__.py +5 -0
- bagelquant_data/management/datasets.py +63 -0
- bagelquant_data/management/lake.py +190 -0
- bagelquant_data/management/sources.py +57 -0
- bagelquant_data/management/status.py +55 -0
- bagelquant_data/pipeline/__init__.py +6 -0
- bagelquant_data/pipeline/commit.py +175 -0
- bagelquant_data/pipeline/ingest.py +142 -0
- bagelquant_data/pipeline/update.py +351 -0
- bagelquant_data/query/__init__.py +65 -0
- bagelquant_data/query/field.py +96 -0
- bagelquant_data/query/filters.py +18 -0
- bagelquant_data/query/observations.py +45 -0
- bagelquant_data/query/raw.py +46 -0
- bagelquant_data/query/records.py +17 -0
- bagelquant_data/query/reference.py +18 -0
- bagelquant_data/query/scanner.py +13 -0
- bagelquant_data/sources/__init__.py +1 -0
- bagelquant_data/sources/tushare/__init__.py +5 -0
- bagelquant_data/sources/tushare/authentication.py +16 -0
- bagelquant_data/sources/tushare/client.py +20 -0
- bagelquant_data/sources/tushare/source.py +121 -0
- bagelquant_data/storage/__init__.py +7 -0
- bagelquant_data/storage/atomic.py +29 -0
- bagelquant_data/storage/metadata.py +410 -0
- bagelquant_data/storage/parquet.py +68 -0
- bagelquant_data/storage/paths.py +48 -0
- bagelquant_data/storage/rejected.py +30 -0
- bagelquant_data/storage/staging.py +28 -0
- bagelquant_data-0.1.0.dist-info/METADATA +74 -0
- bagelquant_data-0.1.0.dist-info/RECORD +57 -0
- bagelquant_data-0.1.0.dist-info/WHEEL +4 -0
- bagelquant_data-0.1.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Framework validation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Protocol
|
|
6
|
+
|
|
7
|
+
import polars as pl
|
|
8
|
+
|
|
9
|
+
from bagelquant_data.core.dataset import DatasetSpec
|
|
10
|
+
from bagelquant_data.core.exceptions import ValidationError
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Validator(Protocol):
    """Structural interface for any object able to check canonical records."""

    def validate(self, frame: pl.LazyFrame, spec: DatasetSpec) -> None:
        """Check *frame* against *spec*; implementations raise when it is invalid."""
        ...
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class FrameworkValidator:
    """Source-independent schema checks applied to every dataset.

    Every dataset must expose its ``required_columns``. Non-reference datasets
    must also expose ``asset_id`` and ``time``, and point-in-time datasets must
    expose ``period``.
    """

    def validate(self, frame: pl.LazyFrame, spec: DatasetSpec) -> None:
        present = set(frame.collect_schema().names())
        needed = set(spec.required_columns)
        if spec.reference:
            framework_columns: set[str] = set()
        else:
            framework_columns = {"asset_id", "time"}
        absent = sorted((needed | framework_columns) - present)
        if absent:
            raise ValidationError(f"{spec.source}/{spec.name} missing columns: {absent}")
        if spec.point_in_time and "period" not in present:
            raise ValidationError(f"{spec.source}/{spec.name} PIT dataset requires period")
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""Generic financial transformation API."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import polars as pl
|
|
9
|
+
|
|
10
|
+
from bagelquant_data.core.types import DateLike
|
|
11
|
+
from bagelquant_data.finance.fields import FinancialFieldKind, FinancialFieldSpec
|
|
12
|
+
from bagelquant_data.finance.flows import ytd_to_period
|
|
13
|
+
from bagelquant_data.finance.point_in_time import asof
|
|
14
|
+
from bagelquant_data.finance.ratios import ratio
|
|
15
|
+
from bagelquant_data.finance.rolling import trailing
|
|
16
|
+
from bagelquant_data.finance.shares import weighted_average
|
|
17
|
+
from bagelquant_data.finance.stocks import average_stock
|
|
18
|
+
from bagelquant_data.query.raw import RawQueryService
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class FinanceFacade:
    """User-facing entry point for the generic finance helpers.

    Raw field access is delegated to the injected ``RawQueryService``. The
    module-level transformations are re-exposed as static methods so they can
    be reached from a facade instance.
    """

    def __init__(self, raw_service: RawQueryService) -> None:
        self._raw = raw_service

    def field(
        self,
        dataset: str,
        field: str,
        *,
        source: str,
        start: DateLike | None = None,
        end: DateLike | None = None,
        assets: Sequence[str] | None = None,
        value_name: str = "value",
    ) -> pl.LazyFrame:
        """Load one field as ``asset_id``, ``time``, ``period`` and *value_name*.

        The dataset must contain a ``period`` column, because it is always
        requested.
        """
        projection = ("asset_id", "time", "period", field)
        frame = self._raw.raw(
            dataset,
            source=source,
            start=start,
            end=end,
            assets=assets,
            columns=projection,
        )
        return frame.select(*projection[:3], pl.col(field).alias(value_name))

    def asof(self, data: pl.LazyFrame, observations: pl.LazyFrame, **kwargs: Any) -> pl.LazyFrame | pl.DataFrame:
        """Delegate to the module-level point-in-time ``asof`` helper."""
        return asof(data, observations, **kwargs)

    def latest(
        self,
        dataset: str,
        field: str,
        *,
        source: str,
        observations: pl.LazyFrame,
        value_name: str | None = None,
        collect: bool = False,
    ) -> pl.LazyFrame | pl.DataFrame:
        """Align the latest available *field* value to each observation row.

        The output column is named *value_name*, falling back to *field*.
        """
        name = value_name or field
        events = self.field(dataset, field, source=source, value_name=name)
        return asof(events, observations, value_column=name, output_name=name, collect=collect)

    ytd_to_period = staticmethod(ytd_to_period)
    trailing = staticmethod(trailing)
    average_stock = staticmethod(average_stock)
    weighted_average = staticmethod(weighted_average)
    ratio = staticmethod(ratio)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
__all__ = [
    # facade
    "FinanceFacade",
    # field metadata
    "FinancialFieldKind",
    "FinancialFieldSpec",
    # transformations
    "asof",
    "average_stock",
    "ratio",
    "trailing",
    "weighted_average",
    "ytd_to_period",
]
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Financial field metadata."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from enum import Enum
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class FinancialFieldKind(str, Enum):
    """Closed set of generic financial field semantics.

    The ``str`` mixin makes each member compare equal to its string value.
    """

    # Cumulative year-to-date flow (see finance.flows.ytd_to_period).
    FLOW_YTD = "flow_ytd"
    # Flow already expressed per period.
    FLOW_PERIOD = "flow_period"
    # Point-in-time level such as assets or equity (see finance.stocks).
    STOCK = "stock"
    PER_SHARE = "per_share"
    RATIO = "ratio"
    COUNT = "count"
    # Fallback when the semantics are not known.
    UNKNOWN = "unknown"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass(frozen=True, slots=True)
|
|
22
|
+
class FinancialFieldSpec:
|
|
23
|
+
"""Financial field metadata separate from storage."""
|
|
24
|
+
|
|
25
|
+
source: str
|
|
26
|
+
dataset: str
|
|
27
|
+
field: str
|
|
28
|
+
kind: FinancialFieldKind
|
|
29
|
+
unit: str | None = None
|
|
30
|
+
currency: str | None = None
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Generic flow transformations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import polars as pl
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def ytd_to_period(
    data: pl.LazyFrame,
    *,
    value_column: str = "value",
    frequency: str = "quarter",
    output_name: str | None = None,
) -> pl.LazyFrame:
    """Convert cumulative YTD flow values into period values.

    For every row, the period value is its YTD value minus the YTD value of the
    preceding reported ``period`` in the same calendar year (of ``period``) for
    the same ``asset_id``.

    The preceding value is taken as of the row's own ``time``: it is the latest
    revision of the preceding period with ``time`` <= this row's ``time``. This
    means a restatement of a period is never differenced against an earlier
    revision of the *same* period, and later revisions do not leak backwards.
    Rows with no preceding period available keep their YTD value unchanged.

    Args:
        data: Frame with ``asset_id``, ``time``, ``period`` and *value_column*.
        value_column: Column holding the cumulative YTD values.
        frequency: Accepted for API compatibility. Currently unused; the
            preceding period is derived from the periods present in *data*.
        output_name: Name of the result column (defaults to *value_column*).

    Returns:
        A lazy frame with ``asset_id``, ``time``, ``period`` and the output
        column, sorted by ``asset_id``, ``period``, ``time``.
    """
    output = output_name or value_column
    year = pl.col("period").dt.year()
    keyed = data.select(
        "asset_id",
        "time",
        "period",
        pl.col(value_column).alias("__ytd"),
    ).with_columns(
        year.alias("__year"),
        # Ordinal of the period within (asset, year): 1 for the first reported
        # period, 2 for the next, and so on. Revisions share their period's rank.
        pl.col("period").rank("dense").over("asset_id", year).cast(pl.Int64).alias("__rank"),
    )
    # Every revision of every period, keyed by the rank of the period that follows it.
    prior = keyed.select(
        "asset_id",
        "__year",
        (pl.col("__rank") + 1).alias("__rank"),
        pl.col("time").alias("__prior_time"),
        pl.col("__ytd").alias("__prior_ytd"),
    ).sort("__prior_time")
    # Point-in-time lookup: latest revision of the preceding period known at `time`.
    aligned = keyed.sort("time").join_asof(
        prior,
        left_on="time",
        right_on="__prior_time",
        by=["asset_id", "__year", "__rank"],
        strategy="backward",
    )
    period_value = (
        pl.when(pl.col("__prior_ytd").is_null())
        .then(pl.col("__ytd"))
        .otherwise(pl.col("__ytd") - pl.col("__prior_ytd"))
    )
    return (
        aligned.with_columns(period_value.alias(output))
        .sort("asset_id", "period", "time")
        .select("asset_id", "time", "period", output)
    )
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Period helpers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import polars as pl
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def with_period_year(data: pl.LazyFrame) -> pl.LazyFrame:
    """Return *data* with a ``period_year`` column holding the year of ``period``."""
    period_year = pl.col("period").dt.year().alias("period_year")
    return data.with_columns(period_year)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Point-in-time alignment."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import polars as pl
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def asof(
    data: pl.LazyFrame,
    observations: pl.LazyFrame,
    *,
    value_column: str = "value",
    output_name: str | None = None,
    collect: bool = False,
) -> pl.LazyFrame | pl.DataFrame:
    """Align latest event whose availability time is not after observation time.

    For each ``(asset_id, time)`` observation, take *value_column* from the
    last event of the same asset whose ``time`` is <= the observation time
    (backward as-of join).

    Args:
        data: Events with ``asset_id``, ``time`` and *value_column*.
        observations: Rows with ``asset_id`` and ``time`` to align to.
        value_column: Event column to carry forward.
        output_name: Result column name (defaults to *value_column*).
        collect: Return an eager DataFrame instead of a LazyFrame.

    Returns:
        Columns ``time``, ``asset_id`` and the output column.
    """
    target = output_name if output_name else value_column
    keys = ("asset_id", "time")
    left = observations.select("time", "asset_id").sort(*keys)
    right = data.select(*keys, pl.col(value_column)).sort(*keys)
    aligned = left.join_asof(right, on="time", by="asset_id", strategy="backward")
    aligned = aligned.select("time", "asset_id", pl.col(value_column).alias(target))
    if collect:
        return aligned.collect()
    return aligned
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Generic ratio operations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import polars as pl
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def ratio(
    numerator: pl.LazyFrame,
    denominator: pl.LazyFrame,
    *,
    numerator_column: str,
    denominator_column: str,
    output_name: str = "value",
    zero_policy: str = "null",
) -> pl.LazyFrame:
    """Join numerator and denominator and compute a generic ratio.

    Rows are matched with an inner join on ``asset_id`` and ``time``, plus
    ``period`` when both inputs carry it. The two value columns may share a
    name (e.g. both ``"value"``).

    Args:
        numerator: Frame holding *numerator_column*.
        denominator: Frame holding *denominator_column*.
        numerator_column: Numerator value column.
        denominator_column: Denominator value column.
        output_name: Name of the ratio column.
        zero_policy: ``"null"`` yields null where the denominator is 0.
            ``"raise"`` raises if any joined row has a zero denominator (this
            collects a row count). Any other value yields NaN for zero
            denominators.

    Returns:
        The join keys plus *output_name*.

    Raises:
        ZeroDivisionError: If ``zero_policy == "raise"`` and a zero denominator exists.
    """
    keys = ["asset_id", "time"]
    if "period" in numerator.collect_schema().names() and "period" in denominator.collect_schema().names():
        keys.append("period")
    # Project each side onto the keys plus its own value under a private name.
    # Otherwise an identically named denominator column is suffixed away by the
    # join, and pl.col(denominator_column) would read the numerator instead.
    num = numerator.select(*keys, pl.col(numerator_column).alias("__numerator"))
    den = denominator.select(*keys, pl.col(denominator_column).alias("__denominator"))
    joined = num.join(den, on=keys, how="inner")
    denominator_expr = pl.col("__denominator")
    if zero_policy == "raise":
        zero_count = joined.filter(denominator_expr == 0).select(pl.len()).collect().item()
        if zero_count:
            raise ZeroDivisionError("Ratio denominator contains zero")
    zero_fill = None if zero_policy == "null" else float("nan")
    value = pl.when(denominator_expr == 0).then(zero_fill).otherwise(pl.col("__numerator") / denominator_expr)
    return joined.with_columns(value.alias(output_name)).select(*keys, output_name)
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Generic rolling financial operations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import polars as pl
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def trailing(
    data: pl.LazyFrame,
    *,
    value_column: str = "value",
    periods: int,
    operation: str,
    output_name: str | None = None,
    require_complete: bool = True,
) -> pl.LazyFrame:
    """Compute a trailing window over event-period rows.

    Rows are ordered by ``asset_id``, ``period``, ``time``. Each row's window is
    that row plus the ``periods - 1`` preceding rows of the same asset.

    NOTE(review): the window counts rows, not distinct periods. Multiple
    revisions of one period therefore occupy several slots; confirm that inputs
    hold one row per period.

    Args:
        data: Frame with ``asset_id``, ``time``, ``period`` and *value_column*.
        value_column: Column to aggregate.
        periods: Window length in rows.
        operation: One of ``"sum"``, ``"mean"``, ``"min"``, ``"max"``,
            ``"first"``, ``"last"``.
        output_name: Result column name (defaults to *value_column*).
        require_complete: Drop rows whose window has fewer than *periods* rows.

    Returns:
        ``asset_id``, ``time``, ``period`` and the output column.

    Raises:
        KeyError: If *operation* is not supported.
    """
    output = output_name or value_column
    column = pl.col(value_column)
    expr = {
        "sum": column.rolling_sum(periods),
        "mean": column.rolling_mean(periods),
        "min": column.rolling_min(periods),
        "max": column.rolling_max(periods),
        "first": column.rolling_map(lambda values: values[0], window_size=periods),
        "last": column.rolling_map(lambda values: values[-1], window_size=periods),
    }[operation]
    # Number of rows seen so far for this asset (1-based), i.e. the size of the
    # window ending at this row before it is capped at `periods`. Rolling over
    # pl.len() is not usable here: it is a single value per group, so its
    # rolling sum is null for periods > 1.
    window_count = pl.int_range(pl.len()).over("asset_id") + 1
    result = data.sort("asset_id", "period", "time").with_columns(
        expr.over("asset_id").alias(output),
        window_count.alias("__window_count"),
    )
    if require_complete:
        result = result.filter(pl.col("__window_count") >= periods)
    return result.select("asset_id", "time", "period", output)
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Generic weighted-average support."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import polars as pl
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def weighted_average(
    data: pl.LazyFrame,
    *,
    value_column: str,
    effective_time_column: str,
    period_start_column: str,
    period_end_column: str,
    output_name: str | None = None,
) -> pl.LazyFrame:
    """Compute a generic time-weighted average inside each asset/period.

    Each row is weighted by the number of days between *period_start_column*
    and *period_end_column* (both cast to Date). Rows whose value is null carry
    no weight, so they do not dilute the average. A group whose total weight
    is zero or null yields null instead of dividing by zero.

    Args:
        data: Frame with ``asset_id``, ``time``, ``period`` and the named columns.
        value_column: Column to average.
        effective_time_column: Accepted for API compatibility; currently unused.
        period_start_column: Start of each row's validity span.
        period_end_column: End of each row's validity span.
        output_name: Result column name (defaults to *value_column*).

    Returns:
        One row per ``(asset_id, period)``, with the maximum ``time`` of the group.
    """
    output = output_name or value_column
    span_days = (
        pl.col(period_end_column).cast(pl.Date) - pl.col(period_start_column).cast(pl.Date)
    ).dt.total_days()
    # Only rows that actually have a value contribute weight; otherwise the
    # denominator would count days the numerator never saw.
    weight = pl.when(pl.col(value_column).is_not_null()).then(span_days)
    weighted = data.with_columns(weight.alias("__days")).with_columns(
        (pl.col(value_column) * pl.col("__days")).alias("__weighted")
    )
    total_days = pl.sum("__days")
    return weighted.group_by("asset_id", "period").agg(
        pl.max("time").alias("time"),
        pl.when(total_days != 0).then(pl.sum("__weighted") / total_days).alias(output),
    ).select("asset_id", "time", "period", output)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Generic stock variable transformations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import polars as pl
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def average_stock(
    data: pl.LazyFrame,
    *,
    value_column: str = "value",
    periods: int = 4,
    method: str = "endpoint",
    output_name: str | None = None,
) -> pl.LazyFrame:
    """Average a stock variable (assets, equity, inventory, shares, ...).

    Rows are ordered by ``asset_id``, ``period``, ``time``, and offsets are
    counted in rows per asset.

    Methods:
        ``"endpoint"``: mean of the current value and the value *periods*
            rows earlier.
        ``"period_mean"``: rolling mean over the last *periods* rows.

    Raises:
        ValueError: For any other *method*.
    """
    target = output_name or value_column
    current = pl.col(value_column)
    if method == "period_mean":
        averaged = current.rolling_mean(periods).over("asset_id")
    elif method == "endpoint":
        averaged = (current + current.shift(periods).over("asset_id")) / 2
    else:
        raise ValueError(f"Unsupported average_stock method: {method}")
    ordered = data.sort("asset_id", "period", "time")
    return ordered.with_columns(averaged.alias(target)).select("asset_id", "time", "period", target)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""Dataset management API."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import shutil
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from bagelquant_data.core.dataset import DatasetSpec
|
|
10
|
+
from bagelquant_data.core.exceptions import DatasetNotFoundError, DestructiveOperationError
|
|
11
|
+
from bagelquant_data.storage.metadata import MetadataStore
|
|
12
|
+
from bagelquant_data.storage.paths import LakePaths
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DatasetManager:
    """Register, inspect, toggle and remove dataset specifications.

    Specs are persisted through the ``MetadataStore`` and memoised in an
    in-process cache keyed by ``(source, dataset)``.
    """

    def __init__(self, metadata: MetadataStore, paths: LakePaths) -> None:
        self.metadata = metadata
        self.paths = paths
        # (source, dataset) -> spec; filled by add() and lazily by get().
        self._specs: dict[tuple[str, str], DatasetSpec] = {}

    def add(self, spec: DatasetSpec) -> None:
        """Validate, persist and cache *spec*."""
        self.validate_spec(spec)
        # Persist before caching so that a failed upsert cannot leave a spec in
        # the cache that the metadata store never recorded.
        self.metadata.upsert_dataset(spec)
        self._specs[spec.key] = spec

    def add_from_yaml(self, path: str | Path) -> DatasetSpec:
        """Load a spec from a YAML file, register it and return it."""
        spec = DatasetSpec.from_yaml(path)
        self.add(spec)
        return spec

    def get(self, dataset: str, *, source: str) -> DatasetSpec:
        """Return the spec for ``source/dataset`` from the cache or metadata.

        Raises:
            DatasetNotFoundError: If the dataset is not registered.
        """
        import json

        key = (source, dataset)
        if key in self._specs:
            return self._specs[key]
        row = self.metadata.get_dataset(source, dataset)
        if row is None:
            raise DatasetNotFoundError(f"Dataset is not registered: {source}/{dataset}")
        spec = DatasetSpec.from_mapping(json.loads(row["spec_json"]))
        self._specs[key] = spec
        return spec

    def list(self, source: str | None = None) -> list[dict[str, Any]]:
        """List registered datasets as metadata rows, optionally for one source."""
        return self.metadata.list_datasets(source)

    def enable(self, dataset: str, *, source: str) -> None:
        """Mark ``source/dataset`` as enabled in metadata."""
        self.metadata.set_dataset_enabled(source, dataset, True)

    def disable(self, dataset: str, *, source: str) -> None:
        """Mark ``source/dataset`` as disabled in metadata."""
        self.metadata.set_dataset_enabled(source, dataset, False)

    def validate_spec(self, spec: DatasetSpec) -> None:
        """Hook for registration-time spec checks; currently accepts every spec.

        Non-reference specs are not required to list ``asset_id``/``time`` in
        ``required_columns`` here.

        NOTE(review): presumably those framework columns are enforced when data
        is validated. Confirm before tightening this check.
        """
        return None

    def remove(self, dataset: str, *, source: str, delete_data: bool = False, confirm: bool = False) -> None:
        """Unregister a dataset and optionally delete its stored data.

        Raises:
            DestructiveOperationError: If *delete_data* is requested without *confirm*.
        """
        if delete_data and not confirm:
            raise DestructiveOperationError("Pass confirm=True to delete canonical data")
        self.metadata.remove_dataset(source, dataset)
        self._specs.pop((source, dataset), None)
        if delete_data:
            # Best effort: a missing or partially deleted directory is not an error.
            shutil.rmtree(self.paths.dataset_root(source, dataset), ignore_errors=True)
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""Public DataLake facade."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import polars as pl
|
|
11
|
+
|
|
12
|
+
from bagelquant_data.core.dataset import DatasetSpec
|
|
13
|
+
from bagelquant_data.core.exceptions import ConfigurationError, DatasetNotFoundError
|
|
14
|
+
from bagelquant_data.core.registry import FrameworkRegistries, default_registries
|
|
15
|
+
from bagelquant_data.core.request import RequestContext
|
|
16
|
+
from bagelquant_data.finance import FinanceFacade
|
|
17
|
+
from bagelquant_data.management.datasets import DatasetManager
|
|
18
|
+
from bagelquant_data.management.sources import SourceManager
|
|
19
|
+
from bagelquant_data.management.status import StatusManager
|
|
20
|
+
from bagelquant_data.pipeline.ingest import IngestionPipeline, IngestionReport
|
|
21
|
+
from bagelquant_data.pipeline.update import UpdateReport, combine_reports, update_dataset
|
|
22
|
+
from bagelquant_data.query import QueryFacade
|
|
23
|
+
from bagelquant_data.query.raw import RawQueryService
|
|
24
|
+
from bagelquant_data.storage.metadata import MetadataStore
|
|
25
|
+
from bagelquant_data.storage.parquet import ParquetStore
|
|
26
|
+
from bagelquant_data.storage.paths import LakePaths
|
|
27
|
+
from bagelquant_data.storage.rejected import RejectedStore
|
|
28
|
+
from bagelquant_data.storage.staging import StagingStore
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class DataLake:
    """Source-agnostic local data lake facade.

    Wires together, for one lake root: the path layout, metadata store,
    Parquet store, source/dataset/status managers, query and finance facades,
    the update API and the ingestion pipeline.
    """

    def __init__(self, root: str | Path, registries: FrameworkRegistries | None = None) -> None:
        self.paths = LakePaths.open(root)
        self.paths.ensure()
        self.registries = registries or default_registries()
        self.metadata = MetadataStore(self.paths.database)
        self.parquet = ParquetStore(self.paths, self.metadata)
        self.sources = SourceManager(self.registries, self.metadata)
        self.datasets = DatasetManager(self.metadata, self.paths)
        self.status = StatusManager(self.metadata)
        # The query and finance facades share one raw query service.
        raw = RawQueryService(self.parquet, self.metadata)
        self.query = QueryFacade(raw)
        self.finance = FinanceFacade(raw)
        self.update = UpdateManager(self)
        self._pipeline = IngestionPipeline(
            registries=self.registries,
            parquet=self.parquet,
            metadata=self.metadata,
            staging=StagingStore(self.paths),
            rejected=RejectedStore(self.paths),
        )

    @classmethod
    def open(cls, root: str | Path = "data", registries: FrameworkRegistries | None = None) -> "DataLake":
        """Open or create a local data lake.

        Args:
            root: Lake root directory.
            registries: Optional custom framework registries. Defaults to
                ``default_registries()``, exactly as in ``__init__``.
        """
        return cls(root, registries)

    def ingest_frame(self, spec: DatasetSpec, frame: pl.DataFrame) -> IngestionReport:
        """Convenience method for tests and local file adapters.

        Registers *spec*, then ingests *frame* using the spec's update mode.
        """
        self.datasets.add(spec)
        return self._pipeline.ingest_frame(spec, frame, mode=spec.update_mode)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
class UpdateManager:
    """Public update API bound to one ``DataLake``.

    Resolves the dataset spec and source adapter from the lake, fills in
    tushare-specific defaults (asset universe, trading calendar) and hands the
    request to ``update_dataset``.
    """

    lake: DataLake

    def dataset(self, dataset: str, *, source: str, **kwargs: Any) -> IngestionReport:
        """Update one registered dataset.

        Args:
            dataset: Registered dataset name.
            source: Registered source name.
            **kwargs: Options accepted by ``_request_context``: ``start``, ``end``,
                ``assets``, ``workers``, ``batch_size``, ``source_options``,
                ``progress``, ``max_retries``, ``retry_backoff_seconds`` and
                ``trade_dates``.

        Raises:
            ConfigurationError: If a required default cannot be derived or an
                unsupported option is passed.
        """
        spec = self.lake.datasets.get(dataset, source=source)
        adapter = self.lake.sources.get(source)
        if spec.request_planner == "by_asset" and not kwargs.get("assets"):
            kwargs["assets"] = self._default_assets(source)
        if spec.source == "tushare" and spec.category == "market" and kwargs.get("trade_dates") is None:
            # Derive open trading days from trade_cal only when the caller did
            # not pass trade_dates explicitly, so an explicit value is honoured.
            kwargs["trade_dates"] = self._trade_dates(source, start=kwargs.get("start"), end=kwargs.get("end"))
        context = _request_context(source=source, dataset=dataset, kwargs=kwargs)
        return update_dataset(spec=spec, source_adapter=adapter, pipeline=self.lake._pipeline, context=context)

    def datasets(self, datasets: list[str], *, source: str, **kwargs: Any) -> UpdateReport:
        """Update several datasets of one source and combine their reports."""
        # **kwargs builds a fresh dict per call, so per-dataset defaults filled
        # in by dataset() do not leak between datasets.
        reports = [self.dataset(dataset, source=source, **kwargs) for dataset in datasets]
        return combine_reports(source, reports)

    def source(self, source: str, **kwargs: Any) -> UpdateReport:
        """Update every enabled dataset registered for *source*."""
        names = [row["name"] for row in self.lake.datasets.list(source) if row["enabled"]]
        return self.datasets(names, source=source, **kwargs)

    def _default_assets(self, source: str) -> list[str]:
        """Return the sorted, de-duplicated tushare asset universe from stock_basic."""
        if source != "tushare":
            raise ConfigurationError("assets=... is required for by_asset dataset updates")
        try:
            frame = self.lake.query.reference("stock_basic", source=source, collect=True)
        except DatasetNotFoundError as exc:
            raise ConfigurationError(
                "Tushare by_asset updates require an asset universe. "
                "Update/register stock_basic first or pass assets=[...] explicitly."
            ) from exc
        if isinstance(frame, pl.LazyFrame):
            frame = frame.collect()
        columns = frame.columns
        column = "asset_id" if "asset_id" in columns else "ts_code" if "ts_code" in columns else None
        if column is None:
            raise ConfigurationError("tushare/stock_basic does not contain asset_id or ts_code")
        return [str(value) for value in frame.get_column(column).drop_nulls().unique().sort().to_list()]

    def _trade_dates(self, source: str, *, start: Any, end: Any) -> list[str]:
        """Return open trading days from trade_cal within [start, end] as YYYY-MM-DD."""
        try:
            frame = self.lake.query.reference("trade_cal", source=source, collect=True)
        except DatasetNotFoundError as exc:
            raise ConfigurationError(
                "Tushare market updates require trade_cal. Update/register trade_cal first."
            ) from exc
        if isinstance(frame, pl.LazyFrame):
            frame = frame.collect()
        if frame.is_empty():
            raise ConfigurationError("tushare/trade_cal is empty; update trade_cal before market datasets")
        column = "time" if "time" in frame.columns else "cal_date" if "cal_date" in frame.columns else None
        if column is None:
            raise ConfigurationError("tushare/trade_cal does not contain time or cal_date")
        dates = frame.with_columns(_calendar_date_expr(column).alias("_calendar_date"))
        if start is not None:
            dates = dates.filter(pl.col("_calendar_date") >= _date_literal(start))
        if end is not None:
            dates = dates.filter(pl.col("_calendar_date") <= _date_literal(end))
        if "is_open" in dates.columns:
            dates = dates.filter(pl.col("is_open").cast(pl.Int8, strict=False) == 1)
        result = [
            value.strftime("%Y-%m-%d")
            for value in dates.select("_calendar_date").drop_nulls().unique().sort("_calendar_date").to_series().to_list()
        ]
        if not result:
            raise ConfigurationError(
                f"tushare/trade_cal has no open trading days between {start} and {end}"
            )
        return result
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _request_context(source: str, dataset: str, kwargs: dict[str, Any]) -> RequestContext:
    """Split update keyword arguments into a ``RequestContext``.

    ``start``, ``end`` and ``assets`` become direct fields. The recognised
    tuning options go into ``options``, but only when they are not None. Every
    recognised key is popped from *kwargs* (the dict is mutated), and any key
    left over is rejected.

    Raises:
        ConfigurationError: If *kwargs* contains unrecognised keys.
    """
    window = {name: kwargs.pop(name, None) for name in ("start", "end", "assets")}
    option_names = (
        "workers",
        "batch_size",
        "source_options",
        "progress",
        "max_retries",
        "retry_backoff_seconds",
        "trade_dates",
    )
    popped = [(name, kwargs.pop(name, None)) for name in option_names]
    if kwargs:
        unsupported = ", ".join(sorted(kwargs))
        raise ConfigurationError(f"Unsupported update option(s): {unsupported}")
    options: dict[str, Any] = {name: value for name, value in popped if value is not None}
    return RequestContext(source=source, dataset=dataset, options=options, **window)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _calendar_date_expr(column: str) -> pl.Expr:
    """Expression turning *column* into a Date.

    Eight-character values are parsed as ``YYYYMMDD``; anything else goes
    through a non-strict cast to Date.
    """
    as_text = pl.col(column).cast(pl.String)
    compact = as_text.str.strptime(pl.Date, "%Y%m%d", strict=False)
    fallback = pl.col(column).cast(pl.Date, strict=False)
    return pl.when(as_text.str.len_chars() == 8).then(compact).otherwise(fallback)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _parse_date_text(value: Any):
    """Parse the calendar date at the start of ``str(value)`` and return a ``date``.

    Accepts ISO ``YYYY-MM-DD`` (optionally followed by ``T``/space and a time
    part) and the compact ``YYYYMMDD`` form, e.g. ``"20240102"`` or the integer
    ``20240102``. The compact form matches the format that
    ``_calendar_date_expr`` parses for calendar columns.

    Raises:
        ValueError: If neither form matches.
    """
    text = str(value).strip()
    if "T" in text:
        text = text.split("T", maxsplit=1)[0]
    try:
        return datetime.strptime(text[:10], "%Y-%m-%d").date()
    except ValueError:
        compact = text[:8]
        if len(compact) == 8 and compact.isdigit():
            return datetime.strptime(compact, "%Y%m%d").date()
        raise


def _date_literal(value: Any) -> pl.Expr:
    """Build a polars Date literal from a date-like *value*.

    Objects with ``strftime`` (date/datetime) are cast directly. Anything else
    is parsed from its string form via ``_parse_date_text``.

    Raises:
        ValueError: If a non-date value cannot be parsed.
    """
    if hasattr(value, "strftime"):
        return pl.lit(value).cast(pl.Date, strict=False)
    return pl.lit(_parse_date_text(value))
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""Source management API."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from bagelquant_data.core.exceptions import SourceNotFoundError
|
|
8
|
+
from bagelquant_data.core.registry import FrameworkRegistries
|
|
9
|
+
from bagelquant_data.storage.metadata import MetadataStore
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SourceManager:
    """Register, configure and test source adapters.

    Adapters live in ``registries.sources``. Their registration and saved
    options are mirrored in the ``MetadataStore``, so saved options can be
    re-applied when an adapter is registered again.
    """

    def __init__(self, registries: FrameworkRegistries, metadata: MetadataStore) -> None:
        self.registries = registries
        self.metadata = metadata

    def register(self, source: object) -> None:
        """Register *source* under its ``name`` and re-apply any saved options.

        ``name`` may be a plain attribute or a zero-argument method.
        """
        name = getattr(source, "name")
        if callable(name):
            name = name()
        key = str(name)
        saved_options = self.metadata.source_options(key)
        if saved_options and hasattr(source, "configure"):
            source.configure(**saved_options)  # type: ignore[attr-defined]
        self.registries.sources.register(key, source)
        self.metadata.upsert_source(
            key,
            type(source).__name__,
            configured=bool(saved_options),
        )

    def remove(self, name: str) -> None:
        """Unregister *name* from the registry and delete its metadata."""
        # NOTE(review): reaches into the registry's private ``_items``; a public
        # unregister method on the registry would be preferable.
        self.registries.sources._items.pop(name, None)
        self.metadata.remove_source(name)

    def list(self) -> list[dict[str, Any]]:
        """Return the registered sources as metadata rows."""
        return self.metadata.list_sources()

    def get(self, name: str) -> object:
        """Return the registered adapter.

        Raises:
            SourceNotFoundError: If *name* is not registered.
        """
        try:
            return self.registries.sources.get(name)
        except KeyError as exc:
            raise SourceNotFoundError(f"Source is not registered: {name}") from exc

    def configure(self, name: str, **options: Any) -> None:
        """Apply *options* to a registered source and persist them.

        The new options are merged over any previously saved ones.

        NOTE(review): options are persisted verbatim through the metadata store,
        including credentials such as the tushare ``token``.
        """
        source = self.get(name)
        source.configure(**options)  # type: ignore[attr-defined]
        # Build a fresh dict rather than mutating whatever source_options()
        # returned, and tolerate a falsy/None result when nothing was saved yet
        # (register() already treats that result as possibly falsy).
        saved = {**(self.metadata.source_options(name) or {}), **options}
        self.metadata.upsert_source(name, type(source).__name__, configured=True, options=saved)

    def configure_tushare(self, token: str) -> None:
        """Shortcut for ``configure("tushare", token=token)``."""
        self.configure("tushare", token=token)

    def test(self, name: str) -> None:
        """Run the adapter's own ``test_connection`` check."""
        self.get(name).test_connection()  # type: ignore[attr-defined]
|