jp_idwr_db-0.2.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jp_idwr_db/__init__.py +29 -0
- jp_idwr_db/__main__.py +7 -0
- jp_idwr_db/_internal/__init__.py +11 -0
- jp_idwr_db/_internal/download.py +12 -0
- jp_idwr_db/_internal/read.py +12 -0
- jp_idwr_db/_internal/validation.py +133 -0
- jp_idwr_db/api.py +203 -0
- jp_idwr_db/cli.py +38 -0
- jp_idwr_db/config.py +60 -0
- jp_idwr_db/data_manager.py +162 -0
- jp_idwr_db/datasets.py +94 -0
- jp_idwr_db/http.py +231 -0
- jp_idwr_db/io.py +1189 -0
- jp_idwr_db/py.typed +0 -0
- jp_idwr_db/transform.py +145 -0
- jp_idwr_db/types.py +23 -0
- jp_idwr_db/urls.py +228 -0
- jp_idwr_db/utils.py +76 -0
- jp_idwr_db-0.2.2.dist-info/METADATA +243 -0
- jp_idwr_db-0.2.2.dist-info/RECORD +23 -0
- jp_idwr_db-0.2.2.dist-info/WHEEL +4 -0
- jp_idwr_db-0.2.2.dist-info/entry_points.txt +2 -0
- jp_idwr_db-0.2.2.dist-info/licenses/LICENSE +651 -0
jp_idwr_db/__init__.py
ADDED
@@ -0,0 +1,29 @@
from __future__ import annotations

from .api import get_data, get_latest_week, list_diseases, list_prefectures
from .config import Config, configure, get_config
from .data_manager import ensure_data
from .datasets import load_dataset as load
from .transform import merge, pivot
from .types import DatasetName
from .utils import attach_prefecture_id, prefecture_map

__all__ = [
    "Config",
    "DatasetName",
    "attach_prefecture_id",
    "configure",
    "ensure_data",
    "get_config",
    "get_data",
    "get_latest_week",
    "list_diseases",
    "list_prefectures",
    "load",
    "merge",
    "pivot",
    "prefecture_map",
]

__version__ = "0.2.2"
__data_version__ = __version__
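Taken together, these re-exports define the package's public namespace. A minimal usage sketch of that surface (illustrative only; it assumes the release parquet assets can be downloaded and cached on first use):

import jp_idwr_db as jp

# Everything below is re-exported by the __init__.py shown above.
print(jp.__version__)                              # "0.2.2"
flu = jp.get_data(disease="Influenza", year=2024)  # filtered polars DataFrame
print(flu.height)                                  # number of matching rows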
jp_idwr_db/_internal/__init__.py
ADDED
@@ -0,0 +1,11 @@
"""Internal utilities for jp_idwr_db package.

This module provides low-level functions for downloading and reading data files.
These functions are used by the build pipeline and are not part of the public API.
"""

from __future__ import annotations

from . import download, read, validation

__all__ = ["download", "read", "validation"]
jp_idwr_db/_internal/download.py
ADDED
@@ -0,0 +1,12 @@
"""Internal download utilities.

This module provides low-level download functionality.
For simple data access, use the public jp.get_data() API instead.
"""

from __future__ import annotations

# Re-export download functions from io.py
from ..io import download, download_recent

__all__ = ["download", "download_recent"]
jp_idwr_db/_internal/read.py
ADDED
@@ -0,0 +1,12 @@
"""Internal read utilities.

This module provides low-level file reading functionality.
For simple data access, use the public jp.get_data() API instead.
"""

from __future__ import annotations

# Re-export read function from io.py
from ..io import read

__all__ = ["read"]
jp_idwr_db/_internal/validation.py
ADDED
@@ -0,0 +1,133 @@
"""Validation utilities for data quality checks.

This module provides functions for validating data schemas, detecting duplicates,
and ensuring data quality across different surveillance data sources.
"""

from __future__ import annotations

from typing import cast

import polars as pl


def get_sentinel_only_diseases() -> set[str]:
    """Get sentinel-only diseases (deprecated static helper).

    Returns:
        Empty set. Sentinel-only detection is now computed dynamically in
        smart_merge() based on disease overlap with zensu data.
    """
    return set()


def validate_schema(df: pl.DataFrame, required_columns: list[str] | None = None) -> None:
    """Validate that a DataFrame has the required schema.

    Args:
        df: DataFrame to validate.
        required_columns: List of required column names. If None, uses standard schema.

    Raises:
        ValueError: If required columns are missing.
    """
    if required_columns is None:
        required_columns = ["prefecture", "year", "week", "disease", "count"]

    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")


def validate_no_duplicates(
    df: pl.DataFrame,
    keys: list[str] | None = None,
) -> None:
    """Validate that there are no duplicate records based on key columns.

    Args:
        df: DataFrame to validate.
        keys: List of column names that define uniqueness. If None, uses
            ["prefecture", "year", "week", "disease", "category"].
            Category is included because the same (prefecture, year, week, disease)
            can have multiple categories (e.g., "male", "female", "total").

    Raises:
        ValueError: If duplicate records are found.
    """
    if keys is None:
        # Include category if it exists, since same disease can have multiple categories
        keys = ["prefecture", "year", "week", "disease"]
        if "category" in df.columns:
            keys.append("category")

    # Count occurrences of each unique combination
    dups = df.group_by(keys).agg(pl.len().alias("count")).filter(pl.col("count") > 1)

    if dups.height > 0:
        raise ValueError(
            f"Found {dups.height} duplicate records. First few duplicates:\n{dups.head(5)}"
        )


def validate_date_ranges(df: pl.DataFrame) -> None:
    """Validate that year and week values are reasonable.

    Args:
        df: DataFrame to validate.

    Raises:
        ValueError: If year or week values are out of expected ranges.
    """
    if "year" in df.columns:
        years = df["year"]
        min_year_raw, max_year_raw = years.min(), years.max()
        min_year = cast(int, min_year_raw)
        max_year = cast(int, max_year_raw)
        if min_year < 1999 or max_year > 2030:
            raise ValueError(f"Year values out of expected range: {min_year}-{max_year}")

    if "week" in df.columns:
        weeks = df["week"]
        min_week_raw, max_week_raw = weeks.min(), weeks.max()
        min_week = cast(int, min_week_raw)
        max_week = cast(int, max_week_raw)
        if min_week < 1 or max_week > 53:
            raise ValueError(f"Week values out of valid range: {min_week}-{max_week}")


def smart_merge(
    zensu_df: pl.DataFrame,
    teiten_df: pl.DataFrame,
) -> pl.DataFrame:
    """Merge zensu and teiten data, preferring confirmed (zensu) data.

    This function implements the "prefer confirmed" strategy:
    - Keep ALL zensu (confirmed case) data
    - Add ONLY sentinel diseases that are absent from zensu
    - This avoids duplication while preserving diseases only in sentinel surveillance

    Args:
        zensu_df: Confirmed case data (from zensu/bullet files).
        teiten_df: Sentinel surveillance data (from teiten files).

    Returns:
        Merged DataFrame with no duplicate diseases.

    Example:
        >>> zensu = pl.DataFrame({"disease": ["Influenza", "Tuberculosis"], "count": [100, 10]})
        >>> teiten = pl.DataFrame({"disease": ["Influenza", "RSV"], "count": [120, 50]})
        >>> merged = smart_merge(zensu, teiten)
        >>> # Result: Influenza from zensu + RSV from teiten
    """
    confirmed_diseases = (
        zensu_df.select("disease").drop_nulls().unique().get_column("disease").to_list()
    )

    # Filter teiten to only include diseases not present in confirmed data.
    teiten_filtered = teiten_df.filter(~pl.col("disease").is_in(confirmed_diseases))

    # Combine zensu (all diseases) + teiten (sentinel-only diseases)
    merged = pl.concat([zensu_df, teiten_filtered], how="diagonal_relaxed")

    return merged
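The prefer-confirmed strategy can be exercised directly; a small standalone sketch mirroring the docstring example (the counts are made up, and the _internal import path is not part of the public API):

import polars as pl

from jp_idwr_db._internal.validation import smart_merge

zensu = pl.DataFrame({"disease": ["Influenza", "Tuberculosis"], "count": [100, 10]})
teiten = pl.DataFrame({"disease": ["Influenza", "RSV"], "count": [120, 50]})

merged = smart_merge(zensu, teiten)
# Influenza and Tuberculosis are kept from zensu; only RSV is taken from teiten,
# so the confirmed count (100) wins over the sentinel count (120) for Influenza.
print(merged.sort("disease"))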
jp_idwr_db/api.py
ADDED
@@ -0,0 +1,203 @@
"""New unified API for accessing infectious disease surveillance data.

This module provides the main user-facing API for jp_idwr_db, offering
simple data access with flexible filtering capabilities.
"""

from __future__ import annotations

import logging
from typing import Literal

import polars as pl

from .datasets import load_dataset

logger = logging.getLogger(__name__)


def get_data(
    disease: str | list[str] | None = None,
    prefecture: str | list[str] | None = None,
    year: int | tuple[int, int] | None = None,
    week: int | tuple[int, int] | None = None,
    source: Literal["confirmed", "sentinel", "all"] = "all",
) -> pl.DataFrame:
    """Get infectious disease surveillance data with optional filtering.

    This is the main entry point for accessing jp_idwr_db data. It loads all
    available data (historical + recent, confirmed + sentinel) and applies
    optional filters.

    Args:
        disease: Filter by disease name(s). Case-insensitive partial matching.
            Examples: "Influenza", ["COVID-19", "Influenza"], "RS virus"
        prefecture: Filter by prefecture name(s).
            Examples: "Tokyo", ["Tokyo", "Osaka"]
        year: Filter by single year or (start, end) range (inclusive).
            Examples: 2024, (2020, 2024)
        week: Filter by single week or (start, end) range (inclusive).
            Examples: 10, (1, 52)
        source: Data source filter.
            - "confirmed": Only zensu (all-case reporting) data
            - "sentinel": Only teiten (sentinel surveillance) data
            - "all": Both sources (default)

    Returns:
        DataFrame with standardized schema containing:
        - prefecture: Prefecture name
        - year: ISO year
        - week: ISO week
        - date: Week start date
        - disease: Disease name (normalized)
        - count: Weekly case count
        - per_sentinel: Per-sentinel rate (sentinel only, null for confirmed)
        - source: "Confirmed cases" or "Sentinel surveillance"
        - category: "total", "male", "female" (when available)

    Examples:
        >>> import jp_idwr_db as jp
        >>> # Get all data
        >>> df = jp.get_data()

        >>> # Filter by disease
        >>> flu = jp.get_data(disease="Influenza")

        >>> # Multiple diseases, specific year
        >>> df = jp.get_data(disease=["COVID-19", "Influenza"], year=2024)

        >>> # Prefecture and year range
        >>> tokyo = jp.get_data(prefecture="Tokyo", year=(2020, 2024))

        >>> # Only sentinel data
        >>> sentinel = jp.get_data(source="sentinel", year=2024)

        >>> # Complex filtering
        >>> df = jp.get_data(
        ...     disease=["Influenza", "RS virus"],
        ...     prefecture=["Tokyo", "Osaka"],
        ...     year=(2023, 2025),
        ...     source="all"
        ... )
    """
    # Load unified dataset (cached locally, downloaded from releases on demand).
    try:
        df = load_dataset("unified")
    except Exception:
        logger.warning("Failed to load unified dataset, falling back to bullet dataset")
        try:
            df = load_dataset("bullet")
        except Exception:
            logger.warning("Failed to load bullet dataset, returning empty DataFrame")
            df = pl.DataFrame()

    if df.height == 0:
        return df

    # Apply filters
    if source != "all" and "source" in df.columns:
        source_map = {
            "confirmed": ["Confirmed cases", "All-case reporting"],
            "sentinel": "Sentinel surveillance",
        }
        if source in source_map:
            target = source_map[source]
            if isinstance(target, list):
                df = df.filter(pl.col("source").is_in(target))
            else:
                df = df.filter(pl.col("source") == target)

    if disease is not None:
        diseases = [disease] if isinstance(disease, str) else disease
        # Case-insensitive partial matching
        disease_filter = pl.lit(False)
        for d in diseases:
            disease_filter = disease_filter | pl.col("disease").str.to_lowercase().str.contains(
                d.lower()
            )
        df = df.filter(disease_filter)

    if prefecture is not None:
        prefectures = [prefecture] if isinstance(prefecture, str) else prefecture
        df = df.filter(pl.col("prefecture").is_in(prefectures))

    if year is not None:
        if isinstance(year, tuple):
            start_year, end_year = year
            df = df.filter((pl.col("year") >= start_year) & (pl.col("year") <= end_year))
        else:
            df = df.filter(pl.col("year") == year)

    if week is not None:
        if isinstance(week, tuple):
            start_week, end_week = week
            df = df.filter((pl.col("week") >= start_week) & (pl.col("week") <= end_week))
        else:
            df = df.filter(pl.col("week") == week)

    return df


def list_diseases(source: Literal["confirmed", "sentinel", "all"] = "all") -> list[str]:
    """Get list of available disease names.

    Args:
        source: Filter by data source - "confirmed", "sentinel", or "all".

    Returns:
        Sorted list of disease names.

    Example:
        >>> import jp_idwr_db as jp
        >>> all_diseases = jp.list_diseases()
        >>> sentinel_only = jp.list_diseases(source="sentinel")
    """
    df = get_data(source=source)
    if df.height == 0:
        return []
    return sorted(df["disease"].unique().to_list())


def list_prefectures() -> list[str]:
    """Get list of prefecture names.

    Returns:
        Sorted list of prefecture names.

    Example:
        >>> import jp_idwr_db as jp
        >>> prefectures = jp.list_prefectures()
        >>> print(prefectures[:3])
        ['Aichi', 'Akita', 'Aomori']
    """
    df = get_data()
    if df.height == 0:
        return []
    return sorted(df["prefecture"].unique().to_list())


def get_latest_week() -> tuple[int, int] | None:
    """Get the latest (year, week) with data available.

    Returns:
        Tuple of (year, week) for the most recent data, or None if no data.

    Example:
        >>> import jp_idwr_db as jp
        >>> latest = jp.get_latest_week()
        >>> if latest:
        ...     year, week = latest
        ...     print(f"Latest data: {year} week {week}")
    """
    df = get_data()
    if df.height == 0:
        return None

    # Check if year column exists, otherwise we can't determine the latest week
    if "year" not in df.columns or "week" not in df.columns:
        logger.warning("Cannot determine latest week: missing year or week column")
        return None

    # Get row with maximum year, then maximum week within that year
    latest = df.sort(["year", "week"], descending=True).head(1)
    return (int(latest["year"][0]), int(latest["week"][0]))
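A brief exploratory session using only the helpers above (hedged: the exact disease and prefecture names returned depend on the cached data release):

import jp_idwr_db as jp

print(jp.list_prefectures()[:3])    # e.g. ['Aichi', 'Akita', 'Aomori']
print(len(jp.list_diseases()))      # number of distinct disease names across both sources

latest = jp.get_latest_week()
if latest is not None:
    year, week = latest
    recent = jp.get_data(year=year, week=week, source="confirmed")
    print(recent.head())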
jp_idwr_db/cli.py
ADDED
@@ -0,0 +1,38 @@
"""Command-line interface for jp_idwr_db."""

from __future__ import annotations

import argparse

from .data_manager import ensure_data


def build_parser() -> argparse.ArgumentParser:
    """Build the top-level CLI argument parser."""
    parser = argparse.ArgumentParser(prog="jp-idwr-db")
    subparsers = parser.add_subparsers(dest="command")

    data_parser = subparsers.add_parser("data", help="Manage local data cache")
    data_subparsers = data_parser.add_subparsers(dest="data_command")

    download_parser = data_subparsers.add_parser("download", help="Download release parquet assets")
    download_parser.add_argument(
        "--version", type=str, default=None, help="Data version (e.g. v0.1.0)"
    )
    download_parser.add_argument("--force", action="store_true", help="Force re-download")

    return parser


def main(argv: list[str] | None = None) -> int:
    """Run CLI entrypoint."""
    parser = build_parser()
    args = parser.parse_args(argv)

    if args.command == "data" and args.data_command == "download":
        data_dir = ensure_data(version=args.version, force=args.force)
        print(data_dir)
        return 0

    parser.print_help()
    return 0
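Because main() accepts an argv list, the console script can also be exercised programmatically; a sketch equivalent to running `jp-idwr-db data download` (assumes a matching data release is published and reachable):

from jp_idwr_db.cli import main

# Same as the shell command: jp-idwr-db data download
exit_code = main(["data", "download"])   # prints the cache directory and returns 0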
jp_idwr_db/config.py
ADDED
@@ -0,0 +1,60 @@
"""Configuration management for jp_idwr_db.

This module provides a global configuration system for controlling package behavior.
"""

from __future__ import annotations

from dataclasses import dataclass, replace
from pathlib import Path

from platformdirs import user_cache_dir


@dataclass(frozen=True)
class Config:
    """Global configuration for jp_idwr_db package.

    Attributes:
        cache_dir: Directory for caching downloaded files.
        rate_limit_per_minute: Maximum number of HTTP requests per minute.
        user_agent: User-Agent header for HTTP requests.
        timeout_seconds: Timeout for HTTP requests in seconds.
        retries: Number of retry attempts for failed requests.
    """

    cache_dir: Path = Path(user_cache_dir("jp_idwr_db"))
    rate_limit_per_minute: int = 20
    user_agent: str = "jp_idwr_db/0.2.2 (+https://github.com/AlFontal/jp-idwr-db)"
    timeout_seconds: float = 30.0
    retries: int = 3


_CONFIG = Config()


def get_config() -> Config:
    """Get the current global configuration.

    Returns:
        The current Config instance.
    """
    return _CONFIG


def configure(**kwargs: object) -> Config:
    """Update the global configuration.

    Args:
        **kwargs: Configuration parameters to update (see Config attributes).

    Returns:
        The updated Config instance.

    Example:
        >>> import jp_idwr_db as jp
        >>> jp.configure(rate_limit_per_minute=10)
    """
    global _CONFIG  # noqa: PLW0603
    _CONFIG = replace(_CONFIG, **kwargs)  # type: ignore[arg-type]
    return _CONFIG
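Since Config is a frozen dataclass, configure() replaces the module-level instance rather than mutating it; a quick sketch of that behavior (assumes no earlier configure() call, so the defaults above still apply):

import jp_idwr_db as jp

before = jp.get_config()
after = jp.configure(timeout_seconds=60.0, retries=5)

assert before.timeout_seconds == 30.0   # the old frozen instance is untouched
assert after.retries == 5
assert jp.get_config() is after         # the global now points at the replaced copy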
jp_idwr_db/data_manager.py
ADDED
@@ -0,0 +1,162 @@
"""Runtime dataset manager for release-hosted parquet assets."""

from __future__ import annotations

import hashlib
import json
import os
import shutil
import zipfile
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as package_version
from pathlib import Path
from typing import Any

import httpx
from platformdirs import user_cache_dir

PACKAGE_NAME = "jp_idwr_db"
DEFAULT_REPO = "AlFontal/jp-idwr-db"
DEFAULT_BASE_URL = f"https://github.com/{DEFAULT_REPO}/releases/download"
ARCHIVE_NAME = "jp_idwr_db-parquet.zip"
MANIFEST_NAME = "jp_idwr_db-manifest.json"
EXPECTED_DATASETS = {
    "sex_prefecture.parquet",
    "place_prefecture.parquet",
    "bullet.parquet",
    "sentinel.parquet",
    "unified.parquet",
    "prefecture_en.parquet",
}


def get_cache_dir() -> Path:
    """Return the base cache directory for package data."""
    override = os.getenv("JPINFECT_CACHE_DIR")
    if override:
        return Path(override).expanduser()
    return Path(user_cache_dir(PACKAGE_NAME))


def _resolve_data_version(version: str | None) -> str:
    """Resolve data version from explicit arg, env var, or package version."""
    if version:
        return version
    env_version = os.getenv("JPINFECT_DATA_VERSION")
    if env_version:
        return env_version
    try:
        pkg_version = package_version("jp-idwr-db")
    except PackageNotFoundError:
        pkg_version = "0.0.0"
    return pkg_version if pkg_version.startswith("v") else f"v{pkg_version}"


def _resolve_base_url(version: str) -> str:
    """Resolve base URL for release assets."""
    base_url = os.getenv("JPINFECT_DATA_BASE_URL")
    if base_url:
        return base_url.rstrip("/")
    return f"{DEFAULT_BASE_URL}/{version}"


def _sha256(path: Path) -> str:
    """Compute SHA256 hash for a file path."""
    digest = hashlib.sha256()
    with path.open("rb") as handle:
        for chunk in iter(lambda: handle.read(1024 * 1024), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _download_file(url: str, dest: Path) -> None:
    """Download URL content to a local path."""
    dest.parent.mkdir(parents=True, exist_ok=True)
    with httpx.stream("GET", url, timeout=60.0, follow_redirects=True) as response:
        response.raise_for_status()
        with dest.open("wb") as handle:
            for chunk in response.iter_bytes():
                handle.write(chunk)


def _verify_manifest(manifest: dict[str, Any]) -> None:
    """Ensure manifest has expected structure."""
    required = {"archive", "archive_sha256", "files"}
    missing = required - set(manifest)
    if missing:
        raise ValueError(f"Invalid manifest; missing keys: {sorted(missing)}")
    files = manifest["files"]
    if not isinstance(files, dict) or not files:
        raise ValueError("Invalid manifest; 'files' must be a non-empty object")


def download_release_assets(version: str, dest_dir: Path) -> tuple[Path, Path]:
    """Download release archive + manifest for a data version."""
    base_url = _resolve_base_url(version)
    manifest_path = dest_dir / MANIFEST_NAME
    archive_path = dest_dir / ARCHIVE_NAME

    _download_file(f"{base_url}/{MANIFEST_NAME}", manifest_path)
    _download_file(f"{base_url}/{ARCHIVE_NAME}", archive_path)
    return archive_path, manifest_path


def _extract_archive(archive_path: Path, dest_dir: Path) -> None:
    """Extract archive into destination directory."""
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(dest_dir)


def ensure_data(version: str | None = None, force: bool = False) -> Path:
    """Ensure parquet assets are available in the local cache.

    Args:
        version: Data release version (for example, ``v0.1.0``).
        force: Re-download and replace cached files.

    Returns:
        Directory path containing parquet files for the resolved version.
    """
    resolved = _resolve_data_version(version)
    cache_dir = get_cache_dir()
    data_dir = cache_dir / "data" / resolved
    marker = data_dir / ".complete"

    if marker.exists() and not force:
        return data_dir

    if force and data_dir.exists():
        shutil.rmtree(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)

    archive_path, manifest_path = download_release_assets(resolved, data_dir)
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    _verify_manifest(manifest)

    expected_archive = str(manifest["archive"])
    if archive_path.name != expected_archive:
        expected_path = data_dir / expected_archive
        archive_path.rename(expected_path)
        archive_path = expected_path

    archive_hash = _sha256(archive_path)
    if archive_hash != manifest["archive_sha256"]:
        raise ValueError("Archive checksum mismatch")

    _extract_archive(archive_path, data_dir)

    file_entries: dict[str, dict[str, Any]] = manifest["files"]
    for rel_name, file_info in file_entries.items():
        file_path = data_dir / rel_name
        if not file_path.exists():
            raise ValueError(f"Missing extracted data file: {rel_name}")
        expected_hash = str(file_info["sha256"])
        if _sha256(file_path) != expected_hash:
            raise ValueError(f"Checksum mismatch for {rel_name}")

    missing_expected = [name for name in EXPECTED_DATASETS if not (data_dir / name).exists()]
    if missing_expected:
        raise ValueError(f"Missing required datasets in cache: {sorted(missing_expected)}")

    marker.write_text("ok\n", encoding="utf-8")
    return data_dir
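An end-to-end sketch of the cache workflow built on ensure_data() (requires network access to the GitHub release assets; JPINFECT_CACHE_DIR is the same override read by get_cache_dir() above):

import os

import polars as pl

from jp_idwr_db.data_manager import ensure_data

# Optional: redirect the cache for this process before anything is downloaded.
os.environ["JPINFECT_CACHE_DIR"] = "/tmp/jp_idwr_db_cache"

data_dir = ensure_data()                                 # version resolved from the installed package
unified = pl.read_parquet(data_dir / "unified.parquet")  # one of the EXPECTED_DATASETS
print(unified.shape)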