retaildata 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- retaildata-0.1.0/PKG-INFO +123 -0
- retaildata-0.1.0/README.md +84 -0
- retaildata-0.1.0/pyproject.toml +46 -0
- retaildata-0.1.0/retaildata/api.py +202 -0
- retaildata-0.1.0/retaildata/cache/manager.py +71 -0
- retaildata-0.1.0/retaildata/cli.py +319 -0
- retaildata-0.1.0/retaildata/config.py +30 -0
- retaildata-0.1.0/retaildata/credentials/encrypted_file_store.py +81 -0
- retaildata-0.1.0/retaildata/credentials/keyring_store.py +16 -0
- retaildata-0.1.0/retaildata/credentials/manager.py +55 -0
- retaildata-0.1.0/retaildata/credentials/store.py +18 -0
- retaildata-0.1.0/retaildata/datasets/registry.py +258 -0
- retaildata-0.1.0/retaildata/main.py +7 -0
- retaildata-0.1.0/retaildata/postprocess/metadata.py +47 -0
- retaildata-0.1.0/retaildata/providers/base.py +23 -0
- retaildata-0.1.0/retaildata/providers/http.py +68 -0
- retaildata-0.1.0/retaildata/providers/kaggle.py +84 -0
- retaildata-0.1.0/retaildata.egg-info/PKG-INFO +123 -0
- retaildata-0.1.0/retaildata.egg-info/SOURCES.txt +23 -0
- retaildata-0.1.0/retaildata.egg-info/dependency_links.txt +1 -0
- retaildata-0.1.0/retaildata.egg-info/entry_points.txt +2 -0
- retaildata-0.1.0/retaildata.egg-info/requires.txt +36 -0
- retaildata-0.1.0/retaildata.egg-info/top_level.txt +1 -0
- retaildata-0.1.0/setup.cfg +4 -0
- retaildata-0.1.0/tests/test_ml.py +35 -0
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: retaildata
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Secure CLI + Python API to download and manage real-world retail benchmark datasets.
|
|
5
|
+
Author-email: Gwang-Jin Kim <gwang.jin.kim.phd@gmail.com>
|
|
6
|
+
Requires-Python: >=3.9
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
Requires-Dist: argon2-cffi>=25.1.0
|
|
9
|
+
Requires-Dist: cryptography>=46.0.5
|
|
10
|
+
Requires-Dist: httpx>=0.28.1
|
|
11
|
+
Requires-Dist: kaggle>=1.7.4.5
|
|
12
|
+
Requires-Dist: keyring>=25.7.0
|
|
13
|
+
Requires-Dist: platformdirs>=4.4.0
|
|
14
|
+
Requires-Dist: pydantic>=2.12.5
|
|
15
|
+
Requires-Dist: pydantic-settings>=2.0.0
|
|
16
|
+
Requires-Dist: rich>=14.3.2
|
|
17
|
+
Requires-Dist: tqdm>=4.67.3
|
|
18
|
+
Requires-Dist: typer>=0.23.0
|
|
19
|
+
Requires-Dist: polars>=0.20.0
|
|
20
|
+
Requires-Dist: huggingface_hub>=0.20.0
|
|
21
|
+
Requires-Dist: ucimlrepo
|
|
22
|
+
Requires-Dist: openml
|
|
23
|
+
Requires-Dist: numpy>=1.20.0
|
|
24
|
+
Provides-Extra: torch
|
|
25
|
+
Requires-Dist: torch>=2.0.0; extra == "torch"
|
|
26
|
+
Provides-Extra: tf
|
|
27
|
+
Requires-Dist: tensorflow>=2.10.0; extra == "tf"
|
|
28
|
+
Provides-Extra: jax
|
|
29
|
+
Requires-Dist: jax>=0.4.0; extra == "jax"
|
|
30
|
+
Requires-Dist: jaxlib>=0.4.0; extra == "jax"
|
|
31
|
+
Provides-Extra: dlt
|
|
32
|
+
Requires-Dist: dlt[duckdb]>=0.3.0; extra == "dlt"
|
|
33
|
+
Provides-Extra: all
|
|
34
|
+
Requires-Dist: torch>=2.0.0; extra == "all"
|
|
35
|
+
Requires-Dist: tensorflow>=2.10.0; extra == "all"
|
|
36
|
+
Requires-Dist: jax>=0.4.0; extra == "all"
|
|
37
|
+
Requires-Dist: jaxlib>=0.4.0; extra == "all"
|
|
38
|
+
Requires-Dist: dlt[duckdb]>=0.3.0; extra == "all"
|
|
39
|
+
|
|
40
|
+
# RetailData
|
|
41
|
+
|
|
42
|
+
A unified interface for fetching and preparing retail datasets for benchmarking and analysis.
|
|
43
|
+
|
|
44
|
+
## Features
|
|
45
|
+
|
|
46
|
+
- **Unified API**: Fetch datasets from various providers (HTTP, Kaggle, Hugging Face, UCI, OpenML) with a single command.
|
|
47
|
+
- **Secure Credentials**: Integrated support for Kaggle and Hugging Face API keys.
|
|
48
|
+
- **Data Benchmark Pack**: Curated retail datasets (Favorita, Rossmann, Instacart, M5, Olist, and more).
|
|
49
|
+
- **Processing Pipeline**: Automatic conversion to high-performance Parquet optimized for Polars.
|
|
50
|
+
- **Cache Management**: Programmatic disk usage tracking and clearing.
|
|
51
|
+
|
|
52
|
+
## Installation
|
|
53
|
+
|
|
54
|
+
```bash
|
|
55
|
+
pip install retaildata
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
Or using `uv` (recommended for development):
|
|
59
|
+
```bash
|
|
60
|
+
uv pip install -e .
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
## Quick Start
|
|
64
|
+
|
|
65
|
+
### CLI
|
|
66
|
+
|
|
67
|
+
1. **List available datasets**:
|
|
68
|
+
```bash
|
|
69
|
+
retaildata list
|
|
70
|
+
```
|
|
71
|
+
|
|
72
|
+
2. **Download a dataset**:
|
|
73
|
+
```bash
|
|
74
|
+
retaildata get test_http
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
3. **Download with Preparation (Parquet)**:
|
|
78
|
+
```bash
|
|
79
|
+
retaildata get online_retail_ii --prepare
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
4. **Manage Credentials (e.g. Kaggle)**:
|
|
83
|
+
```bash
|
|
84
|
+
retaildata auth set kaggle --file ~/.kaggle/kaggle.json
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
5. **Clean Up**:
|
|
88
|
+
```bash
|
|
89
|
+
retaildata rm test_http
|
|
90
|
+
retaildata purge --all
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
### Python API
|
|
94
|
+
|
|
95
|
+
```python
|
|
96
|
+
import retaildata.api as rd
|
|
97
|
+
import polars as pl
|
|
98
|
+
from pathlib import Path
|
|
99
|
+
|
|
100
|
+
# Download and prepare the dataset (with prepare=True the loaded tables are also returned as a dict)
|
|
101
|
+
rd.api.download("online_retail_ii", prepare=True)
|
|
102
|
+
|
|
103
|
+
# Load efficiently with Polars (adjust the path to match your data directory)
|
|
104
|
+
df = pl.scan_parquet("~/.local/share/retaildata/prepared/online_retail_ii/*.parquet").collect()
|
|
105
|
+
print(df.head())
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
## Supported Datasets
|
|
109
|
+
|
|
110
|
+
- `online_retail_ii`: UK-based online retail transactions.
|
|
111
|
+
- `olist`: Brazilian e-commerce dataset.
|
|
112
|
+
- `m5`: Walmart time-series forecasting.
|
|
113
|
+
- `store_sales`: Corporación Favorita (Ecuador) store sales.
|
|
114
|
+
- `rossmann`: Rossmann store sales benchmarks.
|
|
115
|
+
- `instacart`: Online grocery basket analysis.
|
|
116
|
+
- `online_retail_uci`: Classic transactions dataset (UCI).
|
|
117
|
+
- `credit_approval_openml`: Financial benchmarking (OpenML).
|
|
118
|
+
|
|
119
|
+
See `retaildata list` for the full registry.
|
|
120
|
+
|
|
121
|
+
## License
|
|
122
|
+
|
|
123
|
+
This package is licensed under the MIT License. Individual datasets may have their own licenses.
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
# RetailData
|
|
2
|
+
|
|
3
|
+
A unified interface for fetching and preparing retail datasets for benchmarking and analysis.
|
|
4
|
+
|
|
5
|
+
## Features
|
|
6
|
+
|
|
7
|
+
- **Unified API**: Fetch datasets from various providers (HTTP, Kaggle, Hugging Face, UCI, OpenML) with a single command.
|
|
8
|
+
- **Secure Credentials**: Integrated support for Kaggle and Hugging Face API keys.
|
|
9
|
+
- **Data Benchmark Pack**: Curated retail datasets (Favorita, Rossmann, Instacart, M5, Olist, and more).
|
|
10
|
+
- **Processing Pipeline**: Automatic conversion to high-performance Parquet optimized for Polars.
|
|
11
|
+
- **Cache Management**: Programmatic disk usage tracking and clearing.
|
|
12
|
+
|
|
13
|
+
## Installation
|
|
14
|
+
|
|
15
|
+
```bash
|
|
16
|
+
pip install retaildata
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
Or using `uv` (recommended for development):
|
|
20
|
+
```bash
|
|
21
|
+
uv pip install -e .
|
|
22
|
+
```
|
|
23
|
+
|
|
24
|
+
## Quick Start
|
|
25
|
+
|
|
26
|
+
### CLI
|
|
27
|
+
|
|
28
|
+
1. **List available datasets**:
|
|
29
|
+
```bash
|
|
30
|
+
retaildata list
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
2. **Download a dataset**:
|
|
34
|
+
```bash
|
|
35
|
+
retaildata get test_http
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
3. **Download with Preparation (Parquet)**:
|
|
39
|
+
```bash
|
|
40
|
+
retaildata get online_retail_ii --prepare
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
4. **Manage Credentials (e.g. Kaggle)**:
|
|
44
|
+
```bash
|
|
45
|
+
retaildata auth set kaggle --file ~/.kaggle/kaggle.json
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
5. **Clean Up**:
|
|
49
|
+
```bash
|
|
50
|
+
retaildata rm test_http
|
|
51
|
+
retaildata purge --all
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
### Python API
|
|
55
|
+
|
|
56
|
+
```python
|
|
57
|
+
import retaildata.api as rd
|
|
58
|
+
import polars as pl
|
|
59
|
+
from pathlib import Path
|
|
60
|
+
|
|
61
|
+
# Download and prepare the dataset (with prepare=True the loaded tables are also returned as a dict)
|
|
62
|
+
rd.api.download("online_retail_ii", prepare=True)
|
|
63
|
+
|
|
64
|
+
# Load efficiently with Polars (adjust the path to match your data directory)
|
|
65
|
+
df = pl.scan_parquet("~/.local/share/retaildata/prepared/online_retail_ii/*.parquet").collect()
|
|
66
|
+
print(df.head())
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
## Supported Datasets
|
|
70
|
+
|
|
71
|
+
- `online_retail_ii`: UK-based online retail transactions.
|
|
72
|
+
- `olist`: Brazilian e-commerce dataset.
|
|
73
|
+
- `m5`: Walmart time-series forecasting.
|
|
74
|
+
- `store_sales`: Corporación Favorita (Ecuador) store sales.
|
|
75
|
+
- `rossmann`: Rossmann store sales benchmarks.
|
|
76
|
+
- `instacart`: Online grocery basket analysis.
|
|
77
|
+
- `online_retail_uci`: Classic transactions dataset (UCI).
|
|
78
|
+
- `credit_approval_openml`: Financial benchmarking (OpenML).
|
|
79
|
+
|
|
80
|
+
See `retaildata list` for the full registry.
|
|
81
|
+
|
|
82
|
+
## License
|
|
83
|
+
|
|
84
|
+
This package is licensed under the MIT License. Individual datasets may have their own licenses.
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "retaildata"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Secure CLI + Python API to download and manage real-world retail benchmark datasets."
|
|
5
|
+
authors = [
|
|
6
|
+
{name = "Gwang-Jin Kim", email = "gwang.jin.kim.phd@gmail.com"}
|
|
7
|
+
]
|
|
8
|
+
readme = "README.md"
|
|
9
|
+
requires-python = ">=3.9"
|
|
10
|
+
dependencies = [
|
|
11
|
+
"argon2-cffi>=25.1.0",
|
|
12
|
+
"cryptography>=46.0.5",
|
|
13
|
+
"httpx>=0.28.1",
|
|
14
|
+
"kaggle>=1.7.4.5",
|
|
15
|
+
"keyring>=25.7.0",
|
|
16
|
+
"platformdirs>=4.4.0",
|
|
17
|
+
"pydantic>=2.12.5",
|
|
18
|
+
"pydantic-settings>=2.0.0",
|
|
19
|
+
"rich>=14.3.2",
|
|
20
|
+
"tqdm>=4.67.3",
|
|
21
|
+
"typer>=0.23.0",
|
|
22
|
+
"polars>=0.20.0",
|
|
23
|
+
"huggingface_hub>=0.20.0",
|
|
24
|
+
"ucimlrepo",
|
|
25
|
+
"openml",
|
|
26
|
+
"numpy>=1.20.0",
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
[project.optional-dependencies]
|
|
30
|
+
torch = ["torch>=2.0.0"]
|
|
31
|
+
tf = ["tensorflow>=2.10.0"]
|
|
32
|
+
jax = ["jax>=0.4.0", "jaxlib>=0.4.0"]
|
|
33
|
+
dlt = ["dlt[duckdb]>=0.3.0"]
|
|
34
|
+
all = ["torch>=2.0.0", "tensorflow>=2.10.0", "jax>=0.4.0", "jaxlib>=0.4.0", "dlt[duckdb]>=0.3.0"]
|
|
35
|
+
|
|
36
|
+
[dependency-groups]
|
|
37
|
+
dev = [
|
|
38
|
+
"pytest>=7.0.0",
|
|
39
|
+
"pytest-mock>=3.10.0",
|
|
40
|
+
]
|
|
41
|
+
|
|
42
|
+
[project.scripts]
|
|
43
|
+
retaildata = "retaildata.main:main"
|
|
44
|
+
|
|
45
|
+
[tool.setuptools.packages.find]
# Discover `retaildata` together with all of its sub-packages (cache,
# credentials, datasets, postprocess, providers, ...). An explicit
# `packages = ["retaildata"]` list covers only the top-level package, and the
# sub-directories ship without `__init__.py` (namespace packages), so they
# must be found via discovery rather than listed by hand.
include = ["retaildata*"]
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import Optional, List, Any, Dict
|
|
3
|
+
from retaildata.datasets.registry import Registry, Dataset
|
|
4
|
+
from retaildata.providers.http import HTTPProvider
|
|
5
|
+
from retaildata.providers.kaggle import KaggleProvider
|
|
6
|
+
from retaildata.config import settings
|
|
7
|
+
from rich import print as rprint
|
|
8
|
+
|
|
9
|
+
class RetailDataAPI:
    """Python entry point for listing, downloading, loading and splitting datasets.

    All files live under one root directory (``settings.final_data_dir`` unless a
    ``data_dir`` is passed) with this layout::

        <data_dir>/raw/<dataset_id>/        files fetched by a provider
        <data_dir>/meta/                    directory handed to providers as ``meta_dir``
        <data_dir>/prepared/<dataset_id>/   Parquet / DuckDB files read by ``load``
    """

    def list_datasets(self) -> List[Dataset]:
        """Return every dataset registered in the ``Registry``."""
        return Registry.list_all()

    def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
        """Return the registered dataset with this ID, or ``None`` if it is unknown."""
        return Registry.get(dataset_id)

    @staticmethod
    def _resolve_data_dir(data_dir: Optional[Path]) -> Path:
        """Return the root data directory to use.

        ``data_dir`` may be a ``Path`` or a plain string. When it is falsy,
        ``settings.final_data_dir`` is used instead.
        """
        # Path() so callers passing a str still get working "/" joins below.
        return Path(data_dir) if data_dir else settings.final_data_dir

    @staticmethod
    def _get_provider(dataset: Dataset) -> Any:
        """Instantiate the provider that fetches ``dataset``, chosen by ``dataset.provider``.

        HTTP and Kaggle providers are imported at module level. The others are
        imported here, only when a dataset that needs them is requested.

        Raises:
            NotImplementedError: if ``dataset.provider`` is not recognised.
        """
        kind = dataset.provider
        if kind == "http":
            return HTTPProvider()
        if kind == "kaggle":
            return KaggleProvider()
        # NOTE(review): providers.hf / uci / openml / dlt / retail_express are not in
        # the 0.1.0 sdist file list, so the branches below currently raise
        # ModuleNotFoundError -- confirm these modules get packaged.
        if kind == "hf":
            from retaildata.providers.hf import HFProvider
            return HFProvider()
        if kind == "uci":
            from retaildata.providers.uci import UCIProvider
            return UCIProvider()
        if kind == "openml":
            from retaildata.providers.openml import OpenMLProvider
            return OpenMLProvider()
        if kind == "dlt":
            # retail_express has its own provider class; every other dlt dataset uses DLTProvider.
            if dataset.id == "retail_express":
                from retaildata.providers.retail_express import RetailExpressProvider
                return RetailExpressProvider()
            from retaildata.providers.dlt import DLTProvider
            return DLTProvider()
        raise NotImplementedError(f"Provider '{dataset.provider}' not yet supported.")

    def download(
        self,
        dataset_id: str,
        data_dir: Optional[Path] = None,
        prepare: bool = False,
        lazy: bool = False,
        sample_fraction: Optional[float] = None,
        stratify_col: Optional[str] = None,
        split_fraction: Optional[float] = None,
        **kwargs
    ) -> Optional[Dict[str, Any]]:
        """
        Downloads a dataset and optionally prepares and loads it.

        Args:
            dataset_id: The ID of the dataset to download.
            data_dir: Optional root directory (Path or str). Defaults to settings.final_data_dir.
            prepare: If True, convert the dataset to Parquet after download and load it.
            lazy: Only used when ``prepare`` is True. Forwarded to ``load`` to get LazyFrames.
            sample_fraction: Optional fraction for sampling (0.0 to 1.0).
            stratify_col: Optional column name for stratified sampling.
            split_fraction: Optional fraction for train/test splitting (e.g. 0.8).
            **kwargs: Additional provider-specific arguments, forwarded to the provider's ``download``.

        Returns:
            The dict returned by ``load`` when ``prepare`` is True, otherwise ``None``.

        Raises:
            ValueError: if ``dataset_id`` is not in the registry.
            NotImplementedError: if the dataset's provider is not supported.
        """
        dataset = self.get_dataset(dataset_id)
        if not dataset:
            raise ValueError(f"Dataset '{dataset_id}' not found in registry.")

        target_dir = self._resolve_data_dir(data_dir)
        # Structure: <data_dir>/raw/<dataset_id>/...
        download_path = target_dir / "raw" / dataset.id
        meta_dir = target_dir / "meta"

        rprint(f"[bold blue]RetailData[/bold blue]: Downloading {dataset.id} to {download_path}")

        provider = self._get_provider(dataset)
        provider.download(dataset, download_path, meta_dir=meta_dir, **kwargs)

        rprint(f"[green]Successfully processed dataset '{dataset.id}'[/green]")

        if not prepare:
            return None

        from retaildata.processing.manager import manager as processing_manager
        rprint(f"[bold blue]RetailData[/bold blue]: Preparing {dataset.id} (converting to Parquet)...")
        processing_manager.process_dataset(
            dataset.id,
            data_dir=target_dir,
            sample_fraction=sample_fraction,
            stratify_col=stratify_col,
            split_fraction=split_fraction,
        )
        return self.load(dataset.id, data_dir=target_dir, lazy=lazy)

    @staticmethod
    def _load_duckdb(file_path: Path, lazy: bool = False) -> Dict[str, Any]:
        """Read every user table of the ``main`` schema in a DuckDB file into Polars.

        Tables whose name starts with ``_dlt`` are skipped. The connection is
        always closed, even if a read fails. With ``lazy=True`` each table is
        still read eagerly and then wrapped with ``.lazy()``.
        """
        import duckdb
        import polars as pl

        tables: Dict[str, Any] = {}
        con = duckdb.connect(str(file_path))
        try:
            rows = con.execute(
                "SELECT table_name FROM information_schema.tables WHERE table_schema = 'main'"
            ).fetchall()
            for (table_name,) in rows:
                if table_name.startswith("_dlt"):
                    continue
                # Quote the identifier so reserved words / special characters in table names work.
                quoted = '"' + table_name.replace('"', '""') + '"'
                frame = pl.read_database(f"SELECT * FROM {quoted}", connection=con)
                tables[table_name] = frame.lazy() if lazy else frame
        finally:
            con.close()
        return tables

    @staticmethod
    def _apply_standard_mapping(data: Dict[str, Any], mapping: Dict[str, str]) -> Dict[str, Any]:
        """Rename loaded tables from their actual keys to standard keys.

        An ``actual_key`` present in ``data`` becomes ``std_key``. Otherwise, if the
        table was split into ``<actual_key>_train`` / ``<actual_key>_test``, the
        parts become ``<std_key>_train`` / ``<std_key>_test``, and a missing test
        part is simply left out. Tables absent from the mapping are dropped.
        """
        mapped: Dict[str, Any] = {}
        for std_key, actual_key in mapping.items():
            if actual_key in data:
                mapped[std_key] = data[actual_key]
            elif f"{actual_key}_train" in data:
                mapped[f"{std_key}_train"] = data[f"{actual_key}_train"]
                test_key = f"{actual_key}_test"
                if test_key in data:
                    mapped[f"{std_key}_test"] = data[test_key]
        return mapped

    def load(
        self,
        dataset_id: str,
        data_dir: Optional[Path] = None,
        lazy: bool = False,
        standardized: bool = False
    ) -> Dict[str, Any]:
        """
        Loads prepared Parquet (and DuckDB) files for a dataset.

        Args:
            dataset_id: The ID of the dataset to load.
            data_dir: Optional root directory (Path or str). Defaults to settings.final_data_dir.
            lazy: If True, returns Polars LazyFrames.
            standardized: If True, uses the dataset's standard_mapping to rename keys (e.g. 'sales', 'calendar').

        Returns:
            A dictionary mapping keys to Polars DataFrames (or LazyFrames). Parquet
            tables are keyed by file stem in sorted file order, followed by DuckDB
            tables. A DuckDB table overrides a Parquet table with the same name.

        Raises:
            FileNotFoundError: if ``<data_dir>/prepared/<dataset_id>`` does not exist.
        """
        import polars as pl

        prepared_dir = self._resolve_data_dir(data_dir) / "prepared" / dataset_id
        if not prepared_dir.exists():
            raise FileNotFoundError(f"Prepared data for '{dataset_id}' not found at {prepared_dir}")

        data: Dict[str, Any] = {}
        # sorted(): glob order is filesystem-dependent, and callers (split_temporal) rely on key order.
        for file_path in sorted(prepared_dir.glob("*.parquet")):
            data[file_path.stem] = pl.scan_parquet(file_path) if lazy else pl.read_parquet(file_path)

        # Check for DuckDB files (M7: dlt integration)
        for file_path in sorted(prepared_dir.glob("*.duckdb")):
            data.update(self._load_duckdb(file_path, lazy=lazy))

        if standardized:
            dataset = self.get_dataset(dataset_id)
            if dataset and dataset.standard_mapping:
                return self._apply_standard_mapping(data, dataset.standard_mapping)

        return data

    def split_temporal(
        self,
        dataset_id: str,
        date_col: str,
        split_date: str,
        table_key: Optional[str] = None,
        data_dir: Optional[Path] = None
    ) -> Dict[str, Any]:
        """
        Splits a specific table in the dataset into train and test sets based on a date.

        Rows with ``date_col < split_date`` go to ``"train"`` and the rest go to
        ``"test"``. Rows whose date is null land in neither split.

        Args:
            dataset_id: The ID of the prepared dataset.
            date_col: Date column. A non-Date/Datetime column is assumed to hold date
                strings and is parsed with ``str.to_datetime``.
            split_date: Split point, as a string parseable by ``str.to_datetime``.
            table_key: Table to split. Defaults to the table mapped to ``"sales"`` in the
                dataset's standard_mapping, else the first table returned by ``load``.
            data_dir: Optional root directory (Path or str).

        Returns:
            ``{"train": DataFrame, "test": DataFrame}``.

        Raises:
            KeyError: if the selected table is not among the prepared tables
                (including when no prepared tables exist at all).
        """
        import polars as pl

        data = self.load(dataset_id, data_dir=data_dir)

        # Identify which table to split
        if table_key:
            key = table_key
        else:
            # Computed lazily-safe: next(iter(...), None) does not fail on an empty dict.
            first_key = next(iter(data), None)
            dataset = self.get_dataset(dataset_id)
            if dataset and dataset.standard_mapping:
                key = dataset.standard_mapping.get("sales", first_key)
            else:
                key = first_key

        if key not in data:
            raise KeyError(
                f"Table {key!r} not found in prepared data for '{dataset_id}'; "
                f"available tables: {sorted(data)}"
            )

        df = data[key]
        if isinstance(df, pl.LazyFrame):
            df = df.collect()

        # Ensure date_col is datetime/date (other dtypes are assumed to be strings).
        if df[date_col].dtype not in [pl.Datetime, pl.Date]:
            df = df.with_columns(pl.col(date_col).str.to_datetime())

        # Convert split_date string to datetime for comparison
        split_dt = pl.lit(split_date).str.to_datetime()

        train = df.filter(pl.col(date_col) < split_dt)
        test = df.filter(pl.col(date_col) >= split_dt)

        return {"train": train, "test": test}
|
|
201
|
+
|
|
202
|
+
# Module-level shared instance, used as `retaildata.api.api` (e.g. `rd.api.download(...)`).
api = RetailDataAPI()
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import shutil
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import List, Dict, Optional
|
|
4
|
+
import json
|
|
5
|
+
from retaildata.config import settings
|
|
6
|
+
from retaildata.datasets.registry import Registry
|
|
7
|
+
|
|
8
|
+
class CacheManager:
    """Inspect and delete the on-disk files of downloaded datasets.

    Each dataset may own one entry, named after its ID, in each of
    ``<data_dir>/raw``, ``<data_dir>/prepared`` and ``<data_dir>/meta``.
    """

    # Sub-directories of the data root that this manager measures and deletes.
    _SUBDIRS = ("raw", "prepared", "meta")

    def __init__(self, data_dir: Optional[Path] = None):
        """Create a manager rooted at ``data_dir`` (default: ``settings.final_data_dir``)."""
        self.data_dir = Path(data_dir) if data_dir is not None else settings.final_data_dir

    def _get_path(self, dataset_id: str, subdir: str) -> Path:
        """Return ``<data_dir>/<subdir>/<dataset_id>``.

        Raises:
            ValueError: if ``dataset_id`` is empty, ``.``/``..``, absolute, or contains
                a path separator, i.e. anything that would not name a single entry
                directly inside ``<data_dir>/<subdir>``. Without this check a
                malformed ID could make ``delete_dataset`` remove directories
                outside the dataset's own folder.
        """
        if dataset_id in ("", ".", "..") or Path(dataset_id).name != dataset_id:
            raise ValueError(f"Invalid dataset id: {dataset_id!r}")
        return self.data_dir / subdir / dataset_id

    def is_downloaded(self, dataset_id: str) -> bool:
        """Return True if ``<data_dir>/meta/<dataset_id>/metadata.json`` exists."""
        return (self._get_path(dataset_id, "meta") / "metadata.json").exists()

    def get_size(self, dataset_id: str) -> int:
        """Return the total size in bytes of the regular files under the dataset's raw/prepared/meta entries."""
        total = 0
        for subdir in self._SUBDIRS:
            root = self._get_path(dataset_id, subdir)
            if root.exists():
                total += sum(p.stat().st_size for p in root.rglob("*") if p.is_file())
        return total

    def list_downloaded(self) -> Dict[str, Dict[str, object]]:
        """List registry-known datasets that have a directory under ``<data_dir>/meta``.

        Note that this checks only for the meta directory, while ``is_downloaded``
        requires ``metadata.json`` inside it.

        Returns:
            ``{dataset_id: {"size": <int bytes>, "path": <str path of the raw dir>}}``.
        """
        meta_dir = self.data_dir / "meta"
        if not meta_dir.exists():
            return {}

        downloaded: Dict[str, Dict[str, object]] = {}
        for entry in meta_dir.iterdir():
            # Only track known datasets; unrelated folders are ignored.
            if entry.is_dir() and Registry.get(entry.name):
                downloaded[entry.name] = {
                    "size": self.get_size(entry.name),
                    "path": str(self._get_path(entry.name, "raw")),
                }
        return downloaded

    @staticmethod
    def _remove(path: Path) -> bool:
        """Delete ``path`` (directory tree, file or symlink); return True if anything was removed."""
        if path.is_symlink() or path.is_file():
            # Unlink the link/file itself; shutil.rmtree refuses to operate on symlinks.
            path.unlink()
            return True
        if path.is_dir():
            shutil.rmtree(path)
            return True
        return False

    def delete_dataset(self, dataset_id: str) -> bool:
        """Delete the dataset's raw, prepared and meta entries.

        Returns:
            True if at least one of them existed and was removed.

        Raises:
            ValueError: if ``dataset_id`` is not a plain single-component name.
        """
        # Build (and therefore validate) every path before deleting anything.
        paths = [self._get_path(dataset_id, subdir) for subdir in self._SUBDIRS]
        # A list, not a generator inside any(): every entry must be removed, no short-circuit.
        removed = [self._remove(path) for path in paths]
        return any(removed)

    def purge_all(self) -> None:
        """Delete the raw, prepared and meta trees under the data directory.

        The data directory itself and any other content inside it are left in place.
        """
        for subdir in self._SUBDIRS:
            self._remove(self.data_dir / subdir)
|
|
70
|
+
|
|
71
|
+
# Module-level shared instance; it is rooted at settings.final_data_dir as read when this module is imported.
manager = CacheManager()
|