bagelquant-data 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. bagelquant_data/__init__.py +32 -0
  2. bagelquant_data/cli/main.py +30 -0
  3. bagelquant_data/core/__init__.py +37 -0
  4. bagelquant_data/core/dataset.py +170 -0
  5. bagelquant_data/core/deduplication.py +40 -0
  6. bagelquant_data/core/exceptions.py +39 -0
  7. bagelquant_data/core/hashing.py +32 -0
  8. bagelquant_data/core/normalization.py +67 -0
  9. bagelquant_data/core/partitioning.py +76 -0
  10. bagelquant_data/core/registry.py +67 -0
  11. bagelquant_data/core/request.py +21 -0
  12. bagelquant_data/core/source.py +38 -0
  13. bagelquant_data/core/types.py +10 -0
  14. bagelquant_data/core/validation.py +32 -0
  15. bagelquant_data/finance/__init__.py +80 -0
  16. bagelquant_data/finance/align.py +5 -0
  17. bagelquant_data/finance/fields.py +30 -0
  18. bagelquant_data/finance/flows.py +24 -0
  19. bagelquant_data/finance/periods.py +11 -0
  20. bagelquant_data/finance/point_in_time.py +27 -0
  21. bagelquant_data/finance/ratios.py +31 -0
  22. bagelquant_data/finance/rolling.py +34 -0
  23. bagelquant_data/finance/shares.py +28 -0
  24. bagelquant_data/finance/stocks.py +27 -0
  25. bagelquant_data/management/__init__.py +5 -0
  26. bagelquant_data/management/datasets.py +63 -0
  27. bagelquant_data/management/lake.py +190 -0
  28. bagelquant_data/management/sources.py +57 -0
  29. bagelquant_data/management/status.py +55 -0
  30. bagelquant_data/pipeline/__init__.py +6 -0
  31. bagelquant_data/pipeline/commit.py +175 -0
  32. bagelquant_data/pipeline/ingest.py +142 -0
  33. bagelquant_data/pipeline/update.py +351 -0
  34. bagelquant_data/query/__init__.py +65 -0
  35. bagelquant_data/query/field.py +96 -0
  36. bagelquant_data/query/filters.py +18 -0
  37. bagelquant_data/query/observations.py +45 -0
  38. bagelquant_data/query/raw.py +46 -0
  39. bagelquant_data/query/records.py +17 -0
  40. bagelquant_data/query/reference.py +18 -0
  41. bagelquant_data/query/scanner.py +13 -0
  42. bagelquant_data/sources/__init__.py +1 -0
  43. bagelquant_data/sources/tushare/__init__.py +5 -0
  44. bagelquant_data/sources/tushare/authentication.py +16 -0
  45. bagelquant_data/sources/tushare/client.py +20 -0
  46. bagelquant_data/sources/tushare/source.py +121 -0
  47. bagelquant_data/storage/__init__.py +7 -0
  48. bagelquant_data/storage/atomic.py +29 -0
  49. bagelquant_data/storage/metadata.py +410 -0
  50. bagelquant_data/storage/parquet.py +68 -0
  51. bagelquant_data/storage/paths.py +48 -0
  52. bagelquant_data/storage/rejected.py +30 -0
  53. bagelquant_data/storage/staging.py +28 -0
  54. bagelquant_data-0.1.0.dist-info/METADATA +74 -0
  55. bagelquant_data-0.1.0.dist-info/RECORD +57 -0
  56. bagelquant_data-0.1.0.dist-info/WHEEL +4 -0
  57. bagelquant_data-0.1.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,32 @@
1
+ """Framework validation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Protocol
6
+
7
+ import polars as pl
8
+
9
+ from bagelquant_data.core.dataset import DatasetSpec
10
+ from bagelquant_data.core.exceptions import ValidationError
11
+
12
+
13
class Validator(Protocol):
    """Structural interface for objects that check canonical records.

    Any object with a compatible ``validate`` method satisfies this protocol;
    no inheritance is required.
    """

    def validate(self, frame: pl.LazyFrame, spec: DatasetSpec) -> None:
        """Check *frame* against *spec*; implementations raise when data is invalid."""
        ...
18
+
19
+
20
class FrameworkValidator:
    """Schema checks shared by every dataset.

    Verifies that the frame exposes all ``spec.required_columns`` and, for
    non-reference datasets, ``asset_id`` and ``time``. Point-in-time datasets
    must additionally expose ``period``.
    """

    def validate(self, frame: pl.LazyFrame, spec: DatasetSpec) -> None:
        """Raise ``ValidationError`` if *frame*'s schema does not satisfy *spec*."""
        present = frozenset(frame.collect_schema().names())
        expected = set(spec.required_columns)
        if not spec.reference:
            # Time-series datasets are always keyed by asset and time.
            expected |= {"asset_id", "time"}
        absent = sorted(expected.difference(present))
        if absent:
            raise ValidationError(f"{spec.source}/{spec.name} missing columns: {absent}")
        if spec.point_in_time and "period" not in present:
            raise ValidationError(f"{spec.source}/{spec.name} PIT dataset requires period")
@@ -0,0 +1,80 @@
1
+ """Generic financial transformation API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Sequence
6
+ from typing import Any
7
+
8
+ import polars as pl
9
+
10
+ from bagelquant_data.core.types import DateLike
11
+ from bagelquant_data.finance.fields import FinancialFieldKind, FinancialFieldSpec
12
+ from bagelquant_data.finance.flows import ytd_to_period
13
+ from bagelquant_data.finance.point_in_time import asof
14
+ from bagelquant_data.finance.ratios import ratio
15
+ from bagelquant_data.finance.rolling import trailing
16
+ from bagelquant_data.finance.shares import weighted_average
17
+ from bagelquant_data.finance.stocks import average_stock
18
+ from bagelquant_data.query.raw import RawQueryService
19
+
20
+
21
class FinanceFacade:
    """User-facing entry point for generic financial transformations.

    Field reads go through the injected raw query service; the stateless
    helpers from the ``finance`` submodules are re-exported as static methods.
    """

    def __init__(self, raw_service: RawQueryService) -> None:
        self._raw = raw_service

    def field(
        self,
        dataset: str,
        field: str,
        *,
        source: str,
        start: DateLike | None = None,
        end: DateLike | None = None,
        assets: Sequence[str] | None = None,
        value_name: str = "value",
    ) -> pl.LazyFrame:
        """Load one field as ``asset_id``/``time``/``period``/*value_name* rows."""
        wanted = ("asset_id", "time", "period", field)
        frame = self._raw.raw(
            dataset,
            source=source,
            start=start,
            end=end,
            assets=assets,
            columns=wanted,
        )
        renamed = pl.col(field).alias(value_name)
        return frame.select("asset_id", "time", "period", renamed)

    def asof(self, data: pl.LazyFrame, observations: pl.LazyFrame, **kwargs: Any) -> pl.LazyFrame | pl.DataFrame:
        """Delegate to the module-level point-in-time ``asof`` helper."""
        return asof(data, observations, **kwargs)

    def latest(
        self,
        dataset: str,
        field: str,
        *,
        source: str,
        observations: pl.LazyFrame,
        value_name: str | None = None,
        collect: bool = False,
    ) -> pl.LazyFrame | pl.DataFrame:
        """Align the latest available value of *field* onto *observations*.

        The output column is named *value_name*, falling back to *field*.
        """
        name = value_name or field
        events = self.field(dataset, field, source=source, value_name=name)
        return asof(events, observations, value_column=name, output_name=name, collect=collect)

    # Stateless helpers exposed on the facade for convenience.
    ytd_to_period = staticmethod(ytd_to_period)
    trailing = staticmethod(trailing)
    average_stock = staticmethod(average_stock)
    weighted_average = staticmethod(weighted_average)
    ratio = staticmethod(ratio)
68
+
69
+
70
# Public API of the finance package.
__all__ = [
    "FinanceFacade", "FinancialFieldKind", "FinancialFieldSpec",
    "asof", "average_stock", "ratio", "trailing", "weighted_average", "ytd_to_period",
]
@@ -0,0 +1,5 @@
1
+ """Alignment helpers."""
2
+
3
+ from bagelquant_data.finance.point_in_time import asof
4
+
5
# Re-export the point-in-time alignment helper.
__all__ = [
    "asof",
]
@@ -0,0 +1,30 @@
1
+ """Financial field metadata."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from enum import Enum
7
+
8
+
9
class FinancialFieldKind(str, Enum):
    """Generic semantic categories for financial fields.

    Members are ``str`` subclasses, so they compare equal to their string values.
    """

    FLOW_YTD = "flow_ytd"  # flow accumulated year-to-date
    FLOW_PERIOD = "flow_period"  # flow for a single period
    STOCK = "stock"  # balance-type value (assets, equity, inventory, shares, ...)
    PER_SHARE = "per_share"
    RATIO = "ratio"
    COUNT = "count"
    UNKNOWN = "unknown"  # semantics not specified
19
+
20
+
21
+ @dataclass(frozen=True, slots=True)
22
+ class FinancialFieldSpec:
23
+ """Financial field metadata separate from storage."""
24
+
25
+ source: str
26
+ dataset: str
27
+ field: str
28
+ kind: FinancialFieldKind
29
+ unit: str | None = None
30
+ currency: str | None = None
@@ -0,0 +1,24 @@
1
+ """Generic flow transformations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import polars as pl
6
+
7
+
8
def ytd_to_period(
    data: pl.LazyFrame,
    *,
    value_column: str = "value",
    frequency: str = "quarter",
    output_name: str | None = None,
) -> pl.LazyFrame:
    """Convert cumulative year-to-date (YTD) flow values into per-period values.

    Rows are ordered by ``asset_id``, ``period`` and ``time``. Within each asset
    and calendar year of ``period``, a row's result is its YTD value minus the
    value of the row immediately before it. When there is no such row (the first
    row of the asset/year) or that row's value is null, the YTD value is kept as-is.

    Args:
        data: Rows with ``asset_id``, ``time``, a temporal ``period`` column and
            *value_column*.
        value_column: Column holding the cumulative YTD values.
        frequency: NOTE(review): accepted but not read by this function.
        output_name: Name of the result column (defaults to *value_column*).

    Returns:
        A frame with ``asset_id``, ``time``, ``period`` and the period values.

    NOTE(review): the subtraction is row-based, not period-based. A missing
    intermediate period, or several rows for the same period (different
    ``time``), makes the "previous" row something other than the prior
    period. Confirm inputs hold one row per contiguous period.
    """
    target = output_name if output_name else value_column
    cumulative = pl.col(value_column)
    prior = cumulative.shift(1).over("asset_id", pl.col("period").dt.year())
    per_period = pl.when(prior.is_not_null()).then(cumulative - prior).otherwise(cumulative)
    ordered = data.sort("asset_id", "period", "time")
    return ordered.with_columns(per_period.alias(target)).select("asset_id", "time", "period", target)
@@ -0,0 +1,11 @@
1
+ """Period helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import polars as pl
6
+
7
+
8
def with_period_year(data: pl.LazyFrame) -> pl.LazyFrame:
    """Return *data* with an added ``period_year`` column (the year of ``period``)."""
    year_of_period = pl.col("period").dt.year()
    return data.with_columns(year_of_period.alias("period_year"))
@@ -0,0 +1,27 @@
1
+ """Point-in-time alignment."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import polars as pl
6
+
7
+
8
def asof(
    data: pl.LazyFrame,
    observations: pl.LazyFrame,
    *,
    value_column: str = "value",
    output_name: str | None = None,
    collect: bool = False,
) -> pl.LazyFrame | pl.DataFrame:
    """Align the latest event whose availability time is not after observation time.

    For each observation row, this takes the backward as-of match on ``time``
    within the same ``asset_id``.

    Args:
        data: Event rows with ``asset_id``, ``time`` (availability time) and
            *value_column*. An optional ``period`` column is used to break ties.
        observations: Rows with ``asset_id`` and ``time`` to align onto.
        value_column: Column of *data* holding the value to align.
        output_name: Name of the aligned value column (defaults to *value_column*).
        collect: Collect the lazy result into a DataFrame before returning.

    Returns:
        ``time``, ``asset_id`` and the aligned value for each observation row.

    When several events of one asset share the same availability ``time`` (for
    example an annual report and the next quarterly report published on the
    same day) and *data* has a ``period`` column, the event with the latest
    ``period`` is chosen. Without the tie-break, the pick would be arbitrary.
    """
    output = output_name or value_column
    event_columns = ["asset_id", "time", value_column]
    sort_keys = ["asset_id", "time"]
    if value_column != "period" and "period" in data.collect_schema().names():
        # The backward join takes the last row with time <= observation time, so
        # ordering ties by period makes that last row the most recent period.
        event_columns.append("period")
        sort_keys.append("period")
    events = data.select(event_columns).sort(sort_keys)
    obs = observations.select("time", "asset_id").sort("asset_id", "time")
    result = obs.join_asof(
        events,
        on="time",
        by="asset_id",
        strategy="backward",
    ).select("time", "asset_id", pl.col(value_column).alias(output))
    return result.collect() if collect else result
@@ -0,0 +1,31 @@
1
+ """Generic ratio operations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import polars as pl
6
+
7
+
8
def ratio(
    numerator: pl.LazyFrame,
    denominator: pl.LazyFrame,
    *,
    numerator_column: str,
    denominator_column: str,
    output_name: str = "value",
    zero_policy: str = "null",
) -> pl.LazyFrame:
    """Join numerator and denominator and compute a generic ratio.

    The frames are inner-joined on ``asset_id`` and ``time``, plus ``period``
    when both frames have it.

    Args:
        numerator: Frame holding *numerator_column*.
        denominator: Frame holding *denominator_column*. It may use the same
            column name as the numerator (e.g. both ``"value"``).
        numerator_column: Numerator column name in *numerator*.
        denominator_column: Denominator column name in *denominator*.
        output_name: Name of the ratio column.
        zero_policy: ``"null"`` makes zero denominators yield null;
            ``"raise"`` raises if any joined denominator is zero (this collects
            the join); any other value makes zero denominators yield NaN.

    Returns:
        The join keys plus the ratio column.

    Raises:
        ZeroDivisionError: with ``zero_policy="raise"`` when a denominator is zero.
    """
    numerator_names = numerator.collect_schema().names()
    denominator_names = denominator.collect_schema().names()
    keys = ["asset_id", "time"]
    if "period" in numerator_names and "period" in denominator_names:
        keys.append("period")
    joined = numerator.join(denominator, on=keys, how="inner", suffix="__den")
    # The join renames a right-hand column that also exists on the left (outside
    # the keys) with the suffix. Without resolving that name, two "value" frames
    # would divide the numerator by itself.
    if denominator_column in numerator_names and denominator_column not in keys:
        denominator_expr = pl.col(f"{denominator_column}__den")
    else:
        denominator_expr = pl.col(denominator_column)
    if zero_policy == "raise":
        zero_count = joined.filter(denominator_expr == 0).select(pl.len()).collect().item()
        if zero_count:
            raise ZeroDivisionError("Ratio denominator contains zero")
    value = pl.when(denominator_expr == 0).then(None if zero_policy == "null" else float("nan")).otherwise(
        pl.col(numerator_column) / denominator_expr
    )
    return joined.with_columns(value.alias(output_name)).select(*keys, output_name)
@@ -0,0 +1,34 @@
1
+ """Generic rolling financial operations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import polars as pl
6
+
7
+
8
def trailing(
    data: pl.LazyFrame,
    *,
    value_column: str = "value",
    periods: int,
    operation: str,
    output_name: str | None = None,
    require_complete: bool = True,
) -> pl.LazyFrame:
    """Compute a trailing window over event-period rows.

    Rows are ordered by ``asset_id``, ``period`` and ``time``. The window
    covers the current row and the ``periods - 1`` rows before it within the
    same asset.

    Args:
        data: Rows with ``asset_id``, ``time``, ``period`` and *value_column*.
        value_column: Column to aggregate.
        periods: Window length in rows.
        operation: One of ``"sum"``, ``"mean"``, ``"min"``, ``"max"``,
            ``"first"``, ``"last"``.
        output_name: Name of the result column (defaults to *value_column*).
        require_complete: Drop rows that have fewer than ``periods`` rows of
            history within their asset.

    Returns:
        ``asset_id``, ``time``, ``period`` and the windowed value.

    Raises:
        KeyError: if *operation* is not supported.
    """
    output = output_name or value_column
    column = pl.col(value_column)
    expr = {
        "sum": column.rolling_sum(periods),
        "mean": column.rolling_mean(periods),
        "min": column.rolling_min(periods),
        "max": column.rolling_max(periods),
        "first": column.rolling_map(lambda values: values[0], window_size=periods),
        "last": column.rolling_map(lambda values: values[-1], window_size=periods),
    }[operation]
    result = data.sort("asset_id", "period", "time").with_columns(
        expr.over("asset_id").alias(output),
        # 1-based position of each row inside its asset. A window is complete once
        # at least `periods` rows are available. (`pl.len()` is one scalar per
        # group, so rolling over it never produces a per-row count.)
        (pl.int_range(0, pl.len()).over("asset_id") + 1).alias("__window_count"),
    )
    if require_complete:
        result = result.filter(pl.col("__window_count") >= periods)
    return result.drop("__window_count").select("asset_id", "time", "period", output)
@@ -0,0 +1,28 @@
1
+ """Generic weighted-average support."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import polars as pl
6
+
7
+
8
def weighted_average(
    data: pl.LazyFrame,
    *,
    value_column: str,
    effective_time_column: str,
    period_start_column: str,
    period_end_column: str,
    output_name: str | None = None,
) -> pl.LazyFrame:
    """Compute a generic time-weighted average inside each asset/period.

    Each row's value is weighted by the number of days between its
    *period_start_column* and *period_end_column* (both cast to ``Date``).
    The group's ``time`` is the latest ``time`` in the group.

    Rows with a null value carry no weight, so they do not bias the average.
    A group whose total weight is zero (or that has no non-null values)
    yields null, matching the null convention of ``ratio``'s default policy.

    Args:
        data: Rows with ``asset_id``, ``time``, ``period`` and the named columns.
        value_column: Column to average.
        effective_time_column: NOTE(review): accepted but not read by this function.
        period_start_column: Start of each row's weighting interval.
        period_end_column: End of each row's weighting interval.
        output_name: Name of the result column (defaults to *value_column*).

    Returns:
        One row per ``asset_id``/``period`` with ``time`` and the weighted average.
    """
    output = output_name or value_column
    span_days = (
        pl.col(period_end_column).cast(pl.Date) - pl.col(period_start_column).cast(pl.Date)
    ).dt.total_days()
    weighted = data.with_columns(
        # Only rows with an observed value carry weight. Otherwise their days would
        # inflate the denominator without contributing to the numerator.
        pl.when(pl.col(value_column).is_not_null()).then(span_days).alias("__days")
    ).with_columns((pl.col(value_column) * pl.col("__days")).alias("__weighted"))
    total_days = pl.sum("__days")
    return weighted.group_by("asset_id", "period").agg(
        pl.max("time").alias("time"),
        pl.when(total_days != 0).then(pl.sum("__weighted") / total_days).alias(output),
    ).select("asset_id", "time", "period", output)
@@ -0,0 +1,27 @@
1
+ """Generic stock variable transformations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import polars as pl
6
+
7
+
8
def average_stock(
    data: pl.LazyFrame,
    *,
    value_column: str = "value",
    periods: int = 4,
    method: str = "endpoint",
    output_name: str | None = None,
) -> pl.LazyFrame:
    """Average stock variables such as assets, equity, inventory, or shares.

    Rows are ordered by ``asset_id``, ``period`` and ``time``. Offsets count
    rows within an asset, not calendar periods.

    Methods:
        ``"endpoint"``: mean of the current value and the value ``periods`` rows earlier.
        ``"period_mean"``: rolling mean over the last ``periods`` rows.

    Raises:
        ValueError: for any other *method*.
    """
    target = output_name if output_name else value_column
    ordered = data.sort("asset_id", "period", "time")
    level = pl.col(value_column)
    if method == "period_mean":
        averaged = level.rolling_mean(periods).over("asset_id")
    elif method == "endpoint":
        averaged = (level + level.shift(periods).over("asset_id")) / 2
    else:
        raise ValueError(f"Unsupported average_stock method: {method}")
    return ordered.with_columns(averaged.alias(target)).select("asset_id", "time", "period", target)
@@ -0,0 +1,5 @@
1
+ """Management API."""
2
+
3
+ from bagelquant_data.management.lake import DataLake
4
+
5
# The management package exposes only the DataLake facade.
__all__ = [
    "DataLake",
]
@@ -0,0 +1,63 @@
1
+ """Dataset management API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import shutil
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from bagelquant_data.core.dataset import DatasetSpec
10
+ from bagelquant_data.core.exceptions import DatasetNotFoundError, DestructiveOperationError
11
+ from bagelquant_data.storage.metadata import MetadataStore
12
+ from bagelquant_data.storage.paths import LakePaths
13
+
14
+
15
class DatasetManager:
    """Register and inspect dataset specifications.

    Specs are persisted through the metadata store and cached in memory.
    """

    def __init__(self, metadata: MetadataStore, paths: LakePaths) -> None:
        self.metadata = metadata
        self.paths = paths
        # Cache keyed by (source, dataset). ``add`` stores under ``spec.key``, which
        # is assumed to have the same layout. TODO confirm against DatasetSpec.
        self._specs: dict[tuple[str, str], DatasetSpec] = {}

    def add(self, spec: DatasetSpec) -> None:
        """Validate *spec*, cache it, and upsert it into the metadata store."""
        self.validate_spec(spec)
        self._specs[spec.key] = spec
        self.metadata.upsert_dataset(spec)

    def add_from_yaml(self, path: str | Path) -> DatasetSpec:
        """Load a spec from a YAML file, register it, and return it."""
        spec = DatasetSpec.from_yaml(path)
        self.add(spec)
        return spec

    def get(self, dataset: str, *, source: str) -> DatasetSpec:
        """Return the spec for *source*/*dataset*, loading it from metadata if not cached.

        Raises:
            DatasetNotFoundError: if the dataset is not registered.
        """
        key = (source, dataset)
        if key in self._specs:
            return self._specs[key]
        row = self.metadata.get_dataset(source, dataset)
        if row is None:
            raise DatasetNotFoundError(f"Dataset is not registered: {source}/{dataset}")
        import json  # local: only needed when a spec is rehydrated from metadata

        spec = DatasetSpec.from_mapping(json.loads(row["spec_json"]))
        self._specs[key] = spec
        return spec

    def list(self, source: str | None = None) -> list[dict[str, Any]]:
        """List registered datasets, optionally restricted to one *source*."""
        return self.metadata.list_datasets(source)

    def enable(self, dataset: str, *, source: str) -> None:
        """Mark a dataset as enabled in metadata."""
        self.metadata.set_dataset_enabled(source, dataset, True)

    def disable(self, dataset: str, *, source: str) -> None:
        """Mark a dataset as disabled in metadata."""
        self.metadata.set_dataset_enabled(source, dataset, False)

    def validate_spec(self, spec: DatasetSpec) -> None:
        """Hook for registration-time spec checks. Currently it accepts every spec.

        Non-reference specs may omit ``asset_id``/``time`` from
        ``required_columns``: ``FrameworkValidator`` adds those two columns to the
        required set when frames are validated, so nothing is enforced here.
        """

    def remove(self, dataset: str, *, source: str, delete_data: bool = False, confirm: bool = False) -> None:
        """Unregister a dataset and optionally delete its stored data.

        Raises:
            DestructiveOperationError: if ``delete_data=True`` without ``confirm=True``.
        """
        if delete_data and not confirm:
            raise DestructiveOperationError("Pass confirm=True to delete canonical data")
        self.metadata.remove_dataset(source, dataset)
        self._specs.pop((source, dataset), None)
        if delete_data:
            # Best effort: a missing or partially deleted directory is not an error.
            shutil.rmtree(self.paths.dataset_root(source, dataset), ignore_errors=True)
@@ -0,0 +1,190 @@
1
+ """Public DataLake facade."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import polars as pl
11
+
12
+ from bagelquant_data.core.dataset import DatasetSpec
13
+ from bagelquant_data.core.exceptions import ConfigurationError, DatasetNotFoundError
14
+ from bagelquant_data.core.registry import FrameworkRegistries, default_registries
15
+ from bagelquant_data.core.request import RequestContext
16
+ from bagelquant_data.finance import FinanceFacade
17
+ from bagelquant_data.management.datasets import DatasetManager
18
+ from bagelquant_data.management.sources import SourceManager
19
+ from bagelquant_data.management.status import StatusManager
20
+ from bagelquant_data.pipeline.ingest import IngestionPipeline, IngestionReport
21
+ from bagelquant_data.pipeline.update import UpdateReport, combine_reports, update_dataset
22
+ from bagelquant_data.query import QueryFacade
23
+ from bagelquant_data.query.raw import RawQueryService
24
+ from bagelquant_data.storage.metadata import MetadataStore
25
+ from bagelquant_data.storage.parquet import ParquetStore
26
+ from bagelquant_data.storage.paths import LakePaths
27
+ from bagelquant_data.storage.rejected import RejectedStore
28
+ from bagelquant_data.storage.staging import StagingStore
29
+
30
+
31
class DataLake:
    """Source-agnostic local data lake facade.

    Construction opens the directory layout under *root* and wires together
    the metadata store, parquet store, source/dataset/status managers,
    query and finance facades, the update API and the ingestion pipeline.
    """

    def __init__(self, root: str | Path, registries: FrameworkRegistries | None = None) -> None:
        self.paths = LakePaths.open(root)
        self.paths.ensure()
        self.registries = registries or default_registries()
        self.metadata = MetadataStore(self.paths.database)
        self.parquet = ParquetStore(self.paths, self.metadata)
        self.sources = SourceManager(self.registries, self.metadata)
        self.datasets = DatasetManager(self.metadata, self.paths)
        self.status = StatusManager(self.metadata)
        # The query and finance facades share a single raw query service.
        raw = RawQueryService(self.parquet, self.metadata)
        self.query = QueryFacade(raw)
        self.finance = FinanceFacade(raw)
        self.update = UpdateManager(self)
        self._pipeline = IngestionPipeline(
            registries=self.registries,
            parquet=self.parquet,
            metadata=self.metadata,
            staging=StagingStore(self.paths),
            rejected=RejectedStore(self.paths),
        )

    @classmethod
    def open(cls, root: str | Path = "data", registries: FrameworkRegistries | None = None) -> "DataLake":
        """Open or create a local data lake.

        Args:
            root: Lake root directory.
            registries: Optional framework registries. When omitted, the
                constructor falls back to ``default_registries()``.
        """
        if registries is None:
            # Keep the original one-argument call so subclasses whose __init__
            # only accepts ``root`` keep working.
            return cls(root)
        return cls(root, registries=registries)

    def ingest_frame(self, spec: DatasetSpec, frame: pl.DataFrame) -> IngestionReport:
        """Convenience method for tests and local file adapters.

        Registers *spec*, then ingests *frame* using the spec's ``update_mode``.
        """
        self.datasets.add(spec)
        return self._pipeline.ingest_frame(spec, frame, mode=spec.update_mode)
66
+
67
+
68
@dataclass
class UpdateManager:
    """Public update API bound to a ``DataLake``."""

    lake: DataLake

    def dataset(self, dataset: str, *, source: str, **kwargs: Any) -> IngestionReport:
        """Update one registered dataset through its source adapter.

        Keyword arguments become the request context; ``_request_context``
        defines the accepted keys. If a ``by_asset`` dataset is called without
        ``assets``, the asset universe is filled in. If a Tushare market dataset
        is called without ``trade_dates``, the open trading days between
        ``start`` and ``end`` are filled in from ``trade_cal``.
        """
        spec = self.lake.datasets.get(dataset, source=source)
        adapter = self.lake.sources.get(source)
        if spec.request_planner == "by_asset" and not kwargs.get("assets"):
            kwargs["assets"] = self._default_assets(source)
        # Respect caller-supplied trade dates (an accepted update option) instead
        # of silently replacing them with the calendar window.
        if spec.source == "tushare" and spec.category == "market" and not kwargs.get("trade_dates"):
            kwargs["trade_dates"] = self._trade_dates(source, start=kwargs.get("start"), end=kwargs.get("end"))
        context = _request_context(source=source, dataset=dataset, kwargs=kwargs)
        return update_dataset(spec=spec, source_adapter=adapter, pipeline=self.lake._pipeline, context=context)

    def datasets(self, datasets: list[str], *, source: str, **kwargs: Any) -> UpdateReport:
        """Update several datasets of one source and combine their reports."""
        reports = [self.dataset(dataset, source=source, **kwargs) for dataset in datasets]
        return combine_reports(source, reports)

    def source(self, source: str, **kwargs: Any) -> UpdateReport:
        """Update every dataset of *source* whose metadata row is enabled."""
        names = [row["name"] for row in self.lake.datasets.list(source) if row["enabled"]]
        return self.datasets(names, source=source, **kwargs)

    def _default_assets(self, source: str) -> list[str]:
        """Return the sorted unique asset ids from ``stock_basic`` (Tushare only).

        Raises:
            ConfigurationError: for non-Tushare sources, or when ``stock_basic``
                is missing or has neither ``asset_id`` nor ``ts_code``.
        """
        if source != "tushare":
            raise ConfigurationError("assets=... is required for by_asset dataset updates")
        try:
            frame = self.lake.query.reference("stock_basic", source=source, collect=True)
        except DatasetNotFoundError as exc:
            raise ConfigurationError(
                "Tushare by_asset updates require an asset universe. "
                "Update/register stock_basic first or pass assets=[...] explicitly."
            ) from exc
        if isinstance(frame, pl.LazyFrame):
            frame = frame.collect()
        columns = frame.columns
        column = "asset_id" if "asset_id" in columns else "ts_code" if "ts_code" in columns else None
        if column is None:
            raise ConfigurationError("tushare/stock_basic does not contain asset_id or ts_code")
        return [str(value) for value in frame.get_column(column).drop_nulls().unique().sort().to_list()]

    def _trade_dates(self, source: str, *, start: Any, end: Any) -> list[str]:
        """Return open trading days from ``trade_cal`` as sorted ``YYYY-MM-DD`` strings.

        *start*/*end* are optional inclusive bounds. When the calendar has an
        ``is_open`` column, only rows where it equals 1 are kept.

        Raises:
            ConfigurationError: if ``trade_cal`` is missing, empty, lacks a date
                column, or has no open days in the requested window.
        """
        try:
            frame = self.lake.query.reference("trade_cal", source=source, collect=True)
        except DatasetNotFoundError as exc:
            raise ConfigurationError(
                "Tushare market updates require trade_cal. Update/register trade_cal first."
            ) from exc
        if isinstance(frame, pl.LazyFrame):
            frame = frame.collect()
        if frame.is_empty():
            raise ConfigurationError("tushare/trade_cal is empty; update trade_cal before market datasets")
        column = "time" if "time" in frame.columns else "cal_date" if "cal_date" in frame.columns else None
        if column is None:
            raise ConfigurationError("tushare/trade_cal does not contain time or cal_date")
        dates = frame.with_columns(_calendar_date_expr(column).alias("_calendar_date"))
        if start is not None:
            dates = dates.filter(pl.col("_calendar_date") >= _date_literal(start))
        if end is not None:
            dates = dates.filter(pl.col("_calendar_date") <= _date_literal(end))
        if "is_open" in dates.columns:
            dates = dates.filter(pl.col("is_open").cast(pl.Int8, strict=False) == 1)
        result = [
            value.strftime("%Y-%m-%d")
            for value in dates.select("_calendar_date").drop_nulls().unique().sort("_calendar_date").to_series().to_list()
        ]
        if not result:
            raise ConfigurationError(
                f"tushare/trade_cal has no open trading days between {start} and {end}"
            )
        return result
140
+
141
+
142
def _request_context(source: str, dataset: str, kwargs: dict[str, Any]) -> RequestContext:
    """Build a ``RequestContext`` from update keyword arguments.

    ``start``, ``end`` and ``assets`` become context fields. The known option
    keys whose value is not ``None`` go into ``options``. All recognised keys are
    popped from *kwargs* (mutating it) before leftovers are checked.

    Raises:
        ConfigurationError: if *kwargs* contains any unrecognised key.
    """
    window = {name: kwargs.pop(name, None) for name in ("start", "end", "assets")}
    option_names = (
        "workers",
        "batch_size",
        "source_options",
        "progress",
        "max_retries",
        "retry_backoff_seconds",
        "trade_dates",
    )
    popped = [(name, kwargs.pop(name, None)) for name in option_names]
    if kwargs:
        leftover = ", ".join(sorted(kwargs))
        raise ConfigurationError(f"Unsupported update option(s): {leftover}")
    options: dict[str, Any] = {name: value for name, value in popped if value is not None}
    return RequestContext(source=source, dataset=dataset, options=options, **window)
174
+
175
+
176
def _calendar_date_expr(column: str) -> pl.Expr:
    """Parse *column* into a polars ``Date`` expression.

    Values whose string form has 8 characters are parsed as ``YYYYMMDD``; all
    others use a non-strict cast to ``Date``. Unparseable values become null.
    """
    as_text = pl.col(column).cast(pl.String)
    compact = as_text.str.strptime(pl.Date, "%Y%m%d", strict=False)
    fallback = pl.col(column).cast(pl.Date, strict=False)
    return pl.when(as_text.str.len_chars() == 8).then(compact).otherwise(fallback)
182
+
183
+
184
def _date_literal(value: Any) -> pl.Expr:
    """Convert a user-supplied date bound into a polars ``Date`` literal.

    Accepts date-like objects (anything with ``strftime``), ISO strings
    (``YYYY-MM-DD``, optionally followed by ``T...`` or a time part), and compact
    ``YYYYMMDD`` strings/ints. The compact form is Tushare's native date format
    and is already accepted by ``_calendar_date_expr``.

    Raises:
        ValueError: if *value* cannot be parsed in any of these formats.
    """
    if hasattr(value, "strftime"):
        return pl.lit(value).cast(pl.Date, strict=False)
    text = str(value).strip()
    if "T" in text:
        text = text.split("T", maxsplit=1)[0]
    if len(text) == 8 and text.isdigit():
        return pl.lit(datetime.strptime(text, "%Y%m%d").date())
    return pl.lit(datetime.strptime(text[:10], "%Y-%m-%d").date())
@@ -0,0 +1,57 @@
1
+ """Source management API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from bagelquant_data.core.exceptions import SourceNotFoundError
8
+ from bagelquant_data.core.registry import FrameworkRegistries
9
+ from bagelquant_data.storage.metadata import MetadataStore
10
+
11
+
12
+ class SourceManager:
13
+ """Register and configure source adapters."""
14
+
15
+ def __init__(self, registries: FrameworkRegistries, metadata: MetadataStore) -> None:
16
+ self.registries = registries
17
+ self.metadata = metadata
18
+
19
+ def register(self, source: object) -> None:
20
+ name = getattr(source, "name")
21
+ if callable(name):
22
+ name = name()
23
+ saved_options = self.metadata.source_options(str(name))
24
+ if saved_options and hasattr(source, "configure"):
25
+ source.configure(**saved_options) # type: ignore[attr-defined]
26
+ self.registries.sources.register(str(name), source)
27
+ self.metadata.upsert_source(
28
+ str(name),
29
+ type(source).__name__,
30
+ configured=bool(saved_options),
31
+ )
32
+
33
+ def remove(self, name: str) -> None:
34
+ self.registries.sources._items.pop(name, None)
35
+ self.metadata.remove_source(name)
36
+
37
+ def list(self) -> list[dict[str, Any]]:
38
+ return self.metadata.list_sources()
39
+
40
+ def get(self, name: str) -> object:
41
+ try:
42
+ return self.registries.sources.get(name)
43
+ except KeyError as exc:
44
+ raise SourceNotFoundError(f"Source is not registered: {name}") from exc
45
+
46
+ def configure(self, name: str, **options: Any) -> None:
47
+ source = self.get(name)
48
+ source.configure(**options) # type: ignore[attr-defined]
49
+ saved = self.metadata.source_options(name)
50
+ saved.update(options)
51
+ self.metadata.upsert_source(name, type(source).__name__, configured=True, options=saved)
52
+
53
+ def configure_tushare(self, token: str) -> None:
54
+ self.configure("tushare", token=token)
55
+
56
+ def test(self, name: str) -> None:
57
+ self.get(name).test_connection() # type: ignore[attr-defined]