juniper-data 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- juniper_data/__init__.py +88 -0
- juniper_data/__main__.py +78 -0
- juniper_data/api/__init__.py +10 -0
- juniper_data/api/app.py +111 -0
- juniper_data/api/middleware.py +95 -0
- juniper_data/api/routes/__init__.py +9 -0
- juniper_data/api/routes/datasets.py +414 -0
- juniper_data/api/routes/generators.py +125 -0
- juniper_data/api/routes/health.py +49 -0
- juniper_data/api/security.py +238 -0
- juniper_data/api/settings.py +109 -0
- juniper_data/core/__init__.py +32 -0
- juniper_data/core/artifacts.py +63 -0
- juniper_data/core/dataset_id.py +38 -0
- juniper_data/core/models.py +135 -0
- juniper_data/core/split.py +120 -0
- juniper_data/generators/__init__.py +15 -0
- juniper_data/generators/arc_agi/__init__.py +11 -0
- juniper_data/generators/arc_agi/generator.py +229 -0
- juniper_data/generators/arc_agi/params.py +56 -0
- juniper_data/generators/checkerboard/__init__.py +15 -0
- juniper_data/generators/checkerboard/generator.py +114 -0
- juniper_data/generators/checkerboard/params.py +32 -0
- juniper_data/generators/circles/__init__.py +11 -0
- juniper_data/generators/circles/generator.py +112 -0
- juniper_data/generators/circles/params.py +31 -0
- juniper_data/generators/csv_import/__init__.py +15 -0
- juniper_data/generators/csv_import/generator.py +198 -0
- juniper_data/generators/csv_import/params.py +48 -0
- juniper_data/generators/gaussian/__init__.py +11 -0
- juniper_data/generators/gaussian/generator.py +149 -0
- juniper_data/generators/gaussian/params.py +53 -0
- juniper_data/generators/mnist/__init__.py +11 -0
- juniper_data/generators/mnist/generator.py +124 -0
- juniper_data/generators/mnist/params.py +39 -0
- juniper_data/generators/spiral/__init__.py +57 -0
- juniper_data/generators/spiral/defaults.py +39 -0
- juniper_data/generators/spiral/generator.py +206 -0
- juniper_data/generators/spiral/params.py +148 -0
- juniper_data/generators/xor/__init__.py +11 -0
- juniper_data/generators/xor/generator.py +162 -0
- juniper_data/generators/xor/params.py +30 -0
- juniper_data/storage/__init__.py +120 -0
- juniper_data/storage/base.py +279 -0
- juniper_data/storage/cached.py +211 -0
- juniper_data/storage/hf_store.py +257 -0
- juniper_data/storage/kaggle_store.py +333 -0
- juniper_data/storage/local_fs.py +232 -0
- juniper_data/storage/memory.py +136 -0
- juniper_data/storage/postgres_store.py +373 -0
- juniper_data/storage/redis_store.py +264 -0
- juniper_data/tests/__init__.py +1 -0
- juniper_data/tests/conftest.py +68 -0
- juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
- juniper_data/tests/integration/__init__.py +1 -0
- juniper_data/tests/integration/test_api.py +283 -0
- juniper_data/tests/integration/test_e2e_workflow.py +378 -0
- juniper_data/tests/integration/test_lifecycle_api.py +304 -0
- juniper_data/tests/integration/test_security_integration.py +189 -0
- juniper_data/tests/integration/test_storage_workflow.py +259 -0
- juniper_data/tests/performance/__init__.py +1 -0
- juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
- juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
- juniper_data/tests/unit/__init__.py +1 -0
- juniper_data/tests/unit/test_api_app.py +206 -0
- juniper_data/tests/unit/test_api_routes.py +407 -0
- juniper_data/tests/unit/test_api_settings.py +100 -0
- juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
- juniper_data/tests/unit/test_artifacts.py +145 -0
- juniper_data/tests/unit/test_cached_store.py +423 -0
- juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
- juniper_data/tests/unit/test_circles_generator.py +256 -0
- juniper_data/tests/unit/test_csv_import_generator.py +345 -0
- juniper_data/tests/unit/test_dataset_id.py +181 -0
- juniper_data/tests/unit/test_gaussian_generator.py +333 -0
- juniper_data/tests/unit/test_hf_store.py +416 -0
- juniper_data/tests/unit/test_init.py +93 -0
- juniper_data/tests/unit/test_kaggle_store.py +469 -0
- juniper_data/tests/unit/test_lifecycle.py +394 -0
- juniper_data/tests/unit/test_main.py +127 -0
- juniper_data/tests/unit/test_middleware.py +79 -0
- juniper_data/tests/unit/test_mnist_generator.py +370 -0
- juniper_data/tests/unit/test_postgres_store.py +490 -0
- juniper_data/tests/unit/test_redis_store.py +500 -0
- juniper_data/tests/unit/test_security.py +281 -0
- juniper_data/tests/unit/test_security_boundaries.py +517 -0
- juniper_data/tests/unit/test_spiral_generator.py +566 -0
- juniper_data/tests/unit/test_split.py +245 -0
- juniper_data/tests/unit/test_storage.py +767 -0
- juniper_data/tests/unit/test_xor_generator.py +223 -0
- juniper_data-0.4.2.dist-info/METADATA +216 -0
- juniper_data-0.4.2.dist-info/RECORD +95 -0
- juniper_data-0.4.2.dist-info/WHEEL +5 -0
- juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
- juniper_data-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
"""Cached dataset storage wrapper for composable caching layers."""
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from juniper_data.core.models import DatasetMeta
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
from .base import DatasetStore
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class CachedDatasetStore(DatasetStore):
    """Composable caching wrapper for dataset storage.

    Wraps a primary store with a cache store for read-through caching.
    Writes always go to the primary store and, when ``write_through`` is
    enabled, to the cache as well; reads check the cache first, then the
    primary.

    The cache is best-effort: exceptions raised by cache operations are
    logged as warnings and swallowed, so a failing cache never prevents the
    primary store from serving a request. Exceptions raised by the primary
    store propagate to the caller (except while warming or populating the
    cache, where they are logged as well).

    Example:
        primary = LocalFSDatasetStore(Path("./data"))
        cache = RedisDatasetStore(host="localhost")
        store = CachedDatasetStore(primary, cache)
    """

    def __init__(
        self,
        primary: DatasetStore,
        cache: DatasetStore,
        write_through: bool = True,
    ) -> None:
        """Initialize the cached store.

        Args:
            primary: Primary (persistent) storage backend.
            cache: Cache storage backend (e.g., Redis, InMemory).
            write_through: If True, writes go to both stores. If False,
                writes only go to primary and cache is populated on read.
        """
        self._primary = primary
        self._cache = cache
        self._write_through = write_through

    @staticmethod
    @contextlib.contextmanager
    def _best_effort(operation: str, dataset_id: str):
        """Run the enclosed cache operation, logging and swallowing any failure.

        Args:
            operation: Short name of the operation, used in the log message.
            dataset_id: Identifier of the dataset being operated on.
        """
        try:
            yield
        except Exception:
            # Deliberately broad: the cache must never break the primary path,
            # but the failure is logged so stale or missing entries are traceable.
            logger.warning("Cache %s failed for dataset %s", operation, dataset_id, exc_info=True)

    @staticmethod
    def _decode_npz(artifact: bytes) -> dict[str, np.ndarray]:
        """Decode NPZ artifact bytes into a dictionary of arrays.

        Args:
            artifact: Raw NPZ bytes as returned by ``get_artifact_bytes``.

        Returns:
            Mapping of array name to array for every entry in the archive.
        """
        import io

        with np.load(io.BytesIO(artifact)) as npz:
            return {name: npz[name] for name in npz.files}

    def _list_all_primary_ids(self, page_size: int = 10000) -> list[str]:
        """Collect every dataset ID from the primary store, one page at a time.

        Args:
            page_size: Number of IDs requested per ``list_datasets`` call.

        Returns:
            All dataset IDs reported by the primary store.
        """
        page = list(self._primary.list_datasets(limit=page_size))
        dataset_ids = list(page)
        seen = set(dataset_ids)
        offset = len(page)

        # A full page means there may be more; a short page means we are done.
        while len(page) >= page_size:
            page = list(self._primary.list_datasets(limit=page_size, offset=offset))
            offset += len(page)
            unseen = [dataset_id for dataset_id in page if dataset_id not in seen]
            if not unseen:
                # Nothing new: either the end was reached or the backend
                # ignores ``offset``. Stop rather than loop forever.
                break
            dataset_ids.extend(unseen)
            seen.update(unseen)

        return dataset_ids

    def save(
        self,
        dataset_id: str,
        meta: DatasetMeta,
        arrays: dict[str, np.ndarray],
    ) -> None:
        """Save dataset to primary store (and optionally cache).

        Args:
            dataset_id: Unique identifier for the dataset.
            meta: Dataset metadata.
            arrays: Dictionary of numpy arrays.
        """
        self._primary.save(dataset_id, meta, arrays)

        if self._write_through:
            with self._best_effort("save", dataset_id):
                self._cache.save(dataset_id, meta, arrays)

    def get_meta(self, dataset_id: str) -> DatasetMeta | None:
        """Get metadata, checking cache first.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            Dataset metadata if found, None otherwise.
        """
        with self._best_effort("get_meta", dataset_id):
            cached = self._cache.get_meta(dataset_id)
            if cached is not None:
                return cached
        return self._primary.get_meta(dataset_id)

    def get_artifact_bytes(self, dataset_id: str) -> bytes | None:
        """Get artifact bytes, checking cache first.

        On a cache miss the artifact is read from the primary store and the
        cache is populated from it (best-effort) before returning.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            NPZ bytes if found, None otherwise.
        """
        with self._best_effort("get_artifact_bytes", dataset_id):
            cached = self._cache.get_artifact_bytes(dataset_id)
            if cached is not None:
                return cached

        artifact = self._primary.get_artifact_bytes(dataset_id)

        if artifact is not None:
            # Read-through: populating the cache needs the metadata as well,
            # and any failure here must not affect the value returned.
            with self._best_effort("populate", dataset_id):
                meta = self._primary.get_meta(dataset_id)
                if meta is not None:
                    self._cache.save(dataset_id, meta, self._decode_npz(artifact))

        return artifact

    def exists(self, dataset_id: str) -> bool:
        """Check if dataset exists in either store.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            True if the dataset exists, False otherwise.
        """
        with self._best_effort("exists", dataset_id):
            if self._cache.exists(dataset_id):
                return True
        return self._primary.exists(dataset_id)

    def delete(self, dataset_id: str) -> bool:
        """Delete dataset from both stores.

        The cache entry is removed first (best-effort) so that a successful
        primary delete is not followed by reads served from the cache.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            True if the dataset was deleted from primary, False otherwise.
        """
        with self._best_effort("delete", dataset_id):
            self._cache.delete(dataset_id)
        return self._primary.delete(dataset_id)

    def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
        """List datasets from primary store.

        Args:
            limit: Maximum number of dataset IDs to return.
            offset: Number of dataset IDs to skip.

        Returns:
            List of dataset IDs.
        """
        return self._primary.list_datasets(limit, offset)

    def update_meta(self, dataset_id: str, meta: DatasetMeta) -> bool:
        """Update metadata in both stores.

        The cache is only touched when the primary update succeeded.

        Args:
            dataset_id: Unique identifier for the dataset.
            meta: Updated dataset metadata.

        Returns:
            True if the dataset was updated in primary, False otherwise.
        """
        result = self._primary.update_meta(dataset_id, meta)

        if result:
            with self._best_effort("update_meta", dataset_id):
                self._cache.update_meta(dataset_id, meta)

        return result

    def list_all_metadata(self) -> list[DatasetMeta]:
        """List all metadata from primary store.

        Returns:
            List of all DatasetMeta objects.
        """
        return self._primary.list_all_metadata()

    def invalidate_cache(self, dataset_id: str) -> bool:
        """Invalidate a specific entry in the cache.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            True if entry was removed from cache, False otherwise (including
            when the cache backend raised).
        """
        try:
            return self._cache.delete(dataset_id)
        except Exception:
            logger.warning("Cache invalidate failed for dataset %s", dataset_id, exc_info=True)
            return False

    def warm_cache(self, dataset_ids: list[str] | None = None) -> int:
        """Populate cache from primary store.

        Datasets that fail to load or cache are logged and skipped.

        Args:
            dataset_ids: Specific IDs to cache, or None for all datasets in
                the primary store (enumerated page by page).

        Returns:
            Number of datasets cached.
        """
        if dataset_ids is None:
            dataset_ids = self._list_all_primary_ids()

        cached_count = 0
        for dataset_id in dataset_ids:
            try:
                meta = self._primary.get_meta(dataset_id)
                artifact = self._primary.get_artifact_bytes(dataset_id)

                if meta is not None and artifact is not None:
                    self._cache.save(dataset_id, meta, self._decode_npz(artifact))
                    cached_count += 1
            except Exception:
                logger.warning("Failed to cache dataset %s", dataset_id, exc_info=True)

        return cached_count
|
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
"""Hugging Face datasets integration for loading external datasets."""
|
|
2
|
+
|
|
3
|
+
from datetime import UTC, datetime
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from juniper_data.core.models import DatasetMeta
|
|
9
|
+
|
|
10
|
+
from .base import DatasetStore
|
|
11
|
+
from .memory import InMemoryDatasetStore
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
from datasets import load_dataset as hf_load_dataset
|
|
15
|
+
|
|
16
|
+
HF_AVAILABLE = True
|
|
17
|
+
except ImportError:
|
|
18
|
+
HF_AVAILABLE = False
|
|
19
|
+
hf_load_dataset = None # type: ignore[assignment]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class HuggingFaceDatasetStore(DatasetStore):
    """Store for loading datasets from the Hugging Face Hub.

    Loads datasets from Hugging Face and converts them to JuniperData format.
    Primarily used as a data source, not for persistent storage: nothing is
    ever written back to the Hub. Every ``DatasetStore`` operation (save,
    get, delete, list, update) is delegated to the configured cache store,
    which also receives each dataset produced by ``load_hf_dataset``.

    Requires the `datasets` package: pip install datasets
    """

    def __init__(
        self,
        cache_store: DatasetStore | None = None,
        cache_dir: str | None = None,
    ) -> None:
        """Initialize the HF store.

        Args:
            cache_store: Optional store for caching loaded datasets. An
                in-memory store is created when omitted.
            cache_dir: Optional local directory for HF dataset cache.

        Raises:
            ImportError: If datasets package is not installed.
        """
        if not HF_AVAILABLE:
            raise ImportError("Hugging Face datasets package not installed. Install with: pip install datasets")

        # Compare against None explicitly: a caller-supplied store may be
        # falsy (e.g. an empty container-like store) and must not be replaced.
        self._cache_store = cache_store if cache_store is not None else InMemoryDatasetStore()
        self._cache_dir = cache_dir

    def load_hf_dataset(
        self,
        dataset_name: str,
        config_name: str | None = None,
        split: str = "train",
        feature_columns: list[str] | None = None,
        label_column: str = "label",
        n_samples: int | None = None,
        seed: int | None = None,
        flatten: bool = True,
        normalize: bool = True,
        one_hot_labels: bool = True,
        train_ratio: float = 0.8,
    ) -> tuple[str, DatasetMeta, dict[str, np.ndarray]]:
        """Load a dataset from Hugging Face and convert to JuniperData format.

        The converted dataset is saved to the cache store before returning.

        Args:
            dataset_name: HF dataset name (e.g., "mnist", "fashion_mnist").
            config_name: Optional dataset configuration.
            split: Dataset split to load.
            feature_columns: Column names for features (auto-detected if None).
            label_column: Column name for labels.
            n_samples: Optional limit on number of samples.
            seed: Random seed for shuffling/sampling. The dataset is only
                shuffled when a seed is given.
            flatten: Flatten image data to 1D.
            normalize: Scale features (see the extraction helpers for the
                exact rule applied to image and tabular data).
            one_hot_labels: One-hot encode labels.
            train_ratio: Fraction of samples (taken from the front) assigned
                to the training split; must lie in [0, 1].

        Returns:
            Tuple of (dataset_id, metadata, arrays).

        Raises:
            ValueError: If ``train_ratio`` is outside [0, 1], or the dataset
                cannot be converted (see ``_extract_features_labels``).
            ImportError: If the `datasets` loader is unavailable.
        """
        # Validate before the (potentially slow) download.
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be between 0 and 1, got {train_ratio!r}")
        if hf_load_dataset is None:
            raise ImportError("Hugging Face datasets package not installed. Install with: pip install datasets")

        ds = hf_load_dataset(  # nosec B615
            dataset_name,
            config_name,
            split=split,
            cache_dir=self._cache_dir,
        )

        if seed is not None:
            ds = ds.shuffle(seed=seed)

        if n_samples is not None:
            ds = ds.select(range(min(n_samples, len(ds))))

        X, y, n_classes = self._extract_features_labels(
            ds,
            feature_columns=feature_columns,
            label_column=label_column,
            flatten=flatten,
            normalize=normalize,
            one_hot_labels=one_hot_labels,
        )

        n_train = int(len(X) * train_ratio)
        X_train, X_test = X[:n_train], X[n_train:]
        y_train, y_test = y[:n_train], y[n_train:]

        # NOTE(review): the ID encodes only name, config and sample count, so
        # loads that differ in split, seed or preprocessing share an ID and
        # overwrite each other in the cache store. Left unchanged because
        # callers may depend on the current format.
        config_suffix = f"-{config_name}" if config_name else ""
        dataset_id = f"hf-{dataset_name}{config_suffix}-{len(X)}"

        class_indices = y.argmax(axis=1) if one_hot_labels else y.flatten().astype(int)
        class_distribution = {str(i): int((class_indices == i).sum()) for i in range(n_classes)}
        meta = DatasetMeta(
            dataset_id=dataset_id,
            generator="huggingface",
            generator_version="1.0.0",
            params={
                "dataset_name": dataset_name,
                "config_name": config_name,
                "split": split,
                "n_samples": len(X),
                "seed": seed,
                "flatten": flatten,
                "normalize": normalize,
                "one_hot_labels": one_hot_labels,
            },
            n_samples=len(X),
            n_features=X.shape[1] if len(X.shape) > 1 else 1,
            n_classes=n_classes,
            n_train=n_train,
            n_test=len(X) - n_train,
            class_distribution=class_distribution,
            created_at=datetime.now(UTC),
            tags=["huggingface", dataset_name],
        )

        arrays = {
            "X_train": X_train,
            "y_train": y_train,
            "X_test": X_test,
            "y_test": y_test,
            "X_full": X,
            "y_full": y,
        }

        self._cache_store.save(dataset_id, meta, arrays)

        return dataset_id, meta, arrays

    def _extract_features_labels(
        self,
        ds: Any,
        feature_columns: list[str] | None,
        label_column: str,
        flatten: bool,
        normalize: bool,
        one_hot_labels: bool,
    ) -> tuple[np.ndarray, np.ndarray, int]:
        """Extract features and labels from HF dataset.

        A single feature column whose name contains "image" is handled by
        ``_extract_images``. Any other selection is stacked column-wise into a
        float32 matrix; with ``normalize`` it is divided by its maximum when
        that maximum exceeds 1.0 (values are not shifted, so negative inputs
        stay negative).

        Args:
            ds: Loaded HF dataset (supports ``column_names`` and column access).
            feature_columns: Feature column names, or None to use every column
                except the label column, "idx" and "id".
            label_column: Column holding the class labels.
            flatten: Passed through to image extraction.
            normalize: Whether to scale the features.
            one_hot_labels: Return labels one-hot encoded instead of as a
                single float column.

        Returns:
            Tuple of (X, y, n_classes), where ``n_classes`` is the largest
            label value plus one.

        Raises:
            ValueError: If no feature columns remain, the dataset has no
                samples, or one-hot encoding is requested for negative labels.
        """
        if feature_columns is None:
            feature_columns = [col for col in ds.column_names if col not in (label_column, "idx", "id")]
        if not feature_columns:
            raise ValueError("No feature columns available after excluding the label and id columns")

        if len(feature_columns) == 1 and "image" in feature_columns[0].lower():
            X = self._extract_images(ds, feature_columns[0], flatten, normalize)
        else:
            features = []
            for col in feature_columns:
                col_data = ds[col]
                # Tensor-like cells expose .numpy(); guard the peek at the
                # first cell so an empty column does not raise IndexError.
                if len(col_data) > 0 and hasattr(col_data[0], "numpy"):
                    col_data = [x.numpy() for x in col_data]
                features.append(np.array(col_data))
            X = np.column_stack(features) if len(features) > 1 else features[0]
            X = X.astype(np.float32)
            if normalize and X.size:
                peak = X.max()
                if peak > 1.0:
                    X = X / peak

        labels = np.array(ds[label_column])
        if labels.size == 0:
            raise ValueError(f"Dataset has no samples in label column {label_column!r}")
        n_classes = int(labels.max()) + 1

        if one_hot_labels:
            # Index with integers; negative values would otherwise wrap around
            # and silently mark the wrong class (HF uses -1 for unlabeled rows).
            label_indices = labels.astype(np.int64)
            if label_indices.min() < 0:
                raise ValueError(
                    f"Cannot one-hot encode negative labels in column {label_column!r}; "
                    "filter unlabeled samples or pass one_hot_labels=False"
                )
            y = np.zeros((len(labels), n_classes), dtype=np.float32)
            y[np.arange(len(labels)), label_indices] = 1.0
        else:
            y = labels.astype(np.float32).reshape(-1, 1)

        return X, y, n_classes

    def _extract_images(
        self,
        ds: Any,
        image_column: str,
        flatten: bool,
        normalize: bool,
    ) -> np.ndarray:
        """Extract and preprocess image data.

        Each item's image is converted to an array: objects with ``convert``
        (PIL-style images) are converted to mode "L" first, objects with
        ``numpy`` are converted via that method, anything else via
        ``np.array``. The stacked result is cast to float32.

        Args:
            ds: Iterable HF dataset whose items are indexable by column name.
            image_column: Name of the column holding the images.
            flatten: Reshape each image to a 1D vector.
            normalize: Divide pixel values by 255.

        Returns:
            Array with one row (or one image) per sample.
        """
        images = []
        for item in ds:
            img = item[image_column]
            if hasattr(img, "convert"):
                img = np.array(img.convert("L"))
            elif hasattr(img, "numpy"):
                img = img.numpy()
            else:
                img = np.array(img)
            images.append(img)

        X = np.stack(images).astype(np.float32)
        if normalize:
            X = X / 255.0
        if flatten:
            X = X.reshape(len(X), -1)

        return X

    def save(
        self,
        dataset_id: str,
        meta: DatasetMeta,
        arrays: dict[str, np.ndarray],
    ) -> None:
        """Save to cache store."""
        self._cache_store.save(dataset_id, meta, arrays)

    def get_meta(self, dataset_id: str) -> DatasetMeta | None:
        """Get from cache store."""
        return self._cache_store.get_meta(dataset_id)

    def get_artifact_bytes(self, dataset_id: str) -> bytes | None:
        """Get from cache store."""
        return self._cache_store.get_artifact_bytes(dataset_id)

    def exists(self, dataset_id: str) -> bool:
        """Check cache store."""
        return self._cache_store.exists(dataset_id)

    def delete(self, dataset_id: str) -> bool:
        """Delete from cache store."""
        return self._cache_store.delete(dataset_id)

    def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
        """List from cache store."""
        return self._cache_store.list_datasets(limit, offset)

    def update_meta(self, dataset_id: str, meta: DatasetMeta) -> bool:
        """Update in cache store."""
        return self._cache_store.update_meta(dataset_id, meta)

    def list_all_metadata(self) -> list[DatasetMeta]:
        """List from cache store."""
        return self._cache_store.list_all_metadata()
|