data-designer 0.1.0-py3-none-any.whl → 0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/config/column_configs.py +29 -4
- data_designer/config/datastore.py +70 -34
- data_designer/config/default_model_settings.py +12 -8
- data_designer/config/sampler_params.py +16 -2
- data_designer/engine/resources/seed_dataset_data_store.py +20 -2
- data_designer/interface/data_designer.py +24 -3
- {data_designer-0.1.0.dist-info → data_designer-0.1.2.dist-info}/METADATA +27 -13
- {data_designer-0.1.0.dist-info → data_designer-0.1.2.dist-info}/RECORD +12 -12
- {data_designer-0.1.0.dist-info → data_designer-0.1.2.dist-info}/WHEEL +0 -0
- {data_designer-0.1.0.dist-info → data_designer-0.1.2.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.0.dist-info → data_designer-0.1.2.dist-info}/licenses/LICENSE +0 -0
data_designer/_version.py
CHANGED

```diff
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.1.0'
-__version_tuple__ = version_tuple = (0, 1, 0)
+__version__ = version = '0.1.2'
+__version_tuple__ = version_tuple = (0, 1, 2)
 
 __commit_id__ = commit_id = None
```
data_designer/config/column_configs.py
CHANGED

```diff
@@ -2,9 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from abc import ABC
-from typing import Literal, Optional, Type, Union
+from typing import Annotated, Literal, Optional, Type, Union
 
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Discriminator, Field, model_validator
 from typing_extensions import Self
 
 from .base import ConfigBase
@@ -89,11 +89,36 @@ class SamplerColumnConfig(SingleColumnConfig):
     """
 
     sampler_type: SamplerType
-    params: SamplerParamsT
-    conditional_params: dict[str, SamplerParamsT] = {}
+    params: Annotated[SamplerParamsT, Discriminator("sampler_type")]
+    conditional_params: dict[str, Annotated[SamplerParamsT, Discriminator("sampler_type")]] = {}
     convert_to: Optional[str] = None
     column_type: Literal["sampler"] = "sampler"
 
+    @model_validator(mode="before")
+    @classmethod
+    def inject_sampler_type_into_params(cls, data: dict) -> dict:
+        """Inject sampler_type into params dict to enable discriminated union resolution.
+
+        This allows users to pass params as a simple dict without the sampler_type field,
+        which will be automatically added based on the outer sampler_type field.
+        """
+        if isinstance(data, dict):
+            sampler_type = data.get("sampler_type")
+            params = data.get("params")
+
+            # If params is a dict and doesn't have sampler_type, inject it
+            if sampler_type and isinstance(params, dict) and "sampler_type" not in params:
+                data["params"] = {"sampler_type": sampler_type, **params}
+
+            # Handle conditional_params similarly
+            conditional_params = data.get("conditional_params")
+            if conditional_params and isinstance(conditional_params, dict):
+                for condition, cond_params in conditional_params.items():
+                    if isinstance(cond_params, dict) and "sampler_type" not in cond_params:
+                        data["conditional_params"][condition] = {"sampler_type": sampler_type, **cond_params}
+
+        return data
+
 
 class LLMTextColumnConfig(SingleColumnConfig):
     """Configuration for text generation columns using Large Language Models.
```
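The change above swaps a plain union for a pydantic discriminated union keyed on `sampler_type`, with a before-validator that copies the outer `sampler_type` into the params dict so the discriminator can resolve it. A minimal self-contained sketch of the same pattern, assuming pydantic v2; the two params classes below are illustrative stand-ins, not the package's own models:

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Discriminator, model_validator


class UniformParams(BaseModel):
    sampler_type: Literal["uniform"] = "uniform"
    low: float
    high: float


class GaussianParams(BaseModel):
    sampler_type: Literal["gaussian"] = "gaussian"
    mean: float
    stddev: float


SamplerParamsT = Union[UniformParams, GaussianParams]


class SamplerColumn(BaseModel):
    sampler_type: str
    params: Annotated[SamplerParamsT, Discriminator("sampler_type")]

    @model_validator(mode="before")
    @classmethod
    def inject_sampler_type(cls, data: dict) -> dict:
        # Copy the outer sampler_type into params so the discriminator can resolve it.
        if isinstance(data, dict):
            sampler_type = data.get("sampler_type")
            params = data.get("params")
            if sampler_type and isinstance(params, dict) and "sampler_type" not in params:
                data["params"] = {"sampler_type": sampler_type, **params}
        return data


# The params dict needs no redundant sampler_type field:
column = SamplerColumn(sampler_type="gaussian", params={"mean": 0.0, "stddev": 1.0})
assert isinstance(column.params, GaussianParams)
```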
data_designer/config/datastore.py
CHANGED

```diff
@@ -31,34 +31,37 @@ class DatastoreSettings(BaseModel):
     token: Optional[str] = Field(default=None, description="If needed, token to use for authentication.")
 
 
-def get_file_column_names(file_path: Union[str, Path], file_type: str) -> list[str]:
-    """
-
-
-
-
-        raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
-    logger.debug(f"0️⃣ Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
-    file_path = matching_files[0]
+def get_file_column_names(file_reference: Union[str, Path, HfFileSystem], file_type: str) -> list[str]:
+    """Get column names from a dataset file.
+
+    Args:
+        file_reference: Path to the dataset file, or an HfFileSystem object.
+        file_type: Type of the dataset file. Must be one of: 'parquet', 'json', 'jsonl', 'csv'.
 
+    Raises:
+        InvalidFilePathError: If the file type is not supported.
+
+    Returns:
+        List of column names.
+    """
     if file_type == "parquet":
         try:
-            schema = pq.read_schema(file_path)
+            schema = pq.read_schema(file_reference)
             if hasattr(schema, "names"):
                 return schema.names
             else:
                 return [field.name for field in schema]
         except Exception as e:
-            logger.warning(f"Failed to process parquet file {file_path}: {e}")
+            logger.warning(f"Failed to process parquet file {file_reference}: {e}")
             return []
     elif file_type in ["json", "jsonl"]:
-        return pd.read_json(file_path, orient="records", lines=True, nrows=1).columns.tolist()
+        return pd.read_json(file_reference, orient="records", lines=True, nrows=1).columns.tolist()
     elif file_type == "csv":
        try:
-            df = pd.read_csv(file_path, nrows=1)
+            df = pd.read_csv(file_reference, nrows=1)
             return df.columns.tolist()
         except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:
-            logger.warning(f"Failed to process CSV file {file_path}: {e}")
+            logger.warning(f"Failed to process CSV file {file_reference}: {e}")
             return []
     else:
         raise InvalidFilePathError(f"🛑 Unsupported file type: {file_type!r}")
@@ -66,12 +69,36 @@ def get_file_column_names(file_path: Union[str, Path], file_type: str) -> list[str]:
 
 def fetch_seed_dataset_column_names(seed_dataset_reference: SeedDatasetReference) -> list[str]:
     if hasattr(seed_dataset_reference, "datastore_settings"):
-        return
+        return fetch_seed_dataset_column_names_from_datastore(
             seed_dataset_reference.repo_id,
             seed_dataset_reference.filename,
             seed_dataset_reference.datastore_settings,
         )
-    return
+    return fetch_seed_dataset_column_names_from_local_file(seed_dataset_reference.dataset)
+
+
+def fetch_seed_dataset_column_names_from_datastore(
+    repo_id: str,
+    filename: str,
+    datastore_settings: Optional[Union[DatastoreSettings, dict]] = None,
+) -> list[str]:
+    file_type = filename.split(".")[-1]
+    if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS:
+        raise InvalidFileFormatError(f"🛑 Unsupported file type: {filename!r}")
+
+    datastore_settings = resolve_datastore_settings(datastore_settings)
+    fs = HfFileSystem(endpoint=datastore_settings.endpoint, token=datastore_settings.token, skip_instance_cache=True)
+
+    file_path = _extract_single_file_path_from_glob_pattern_if_present(f"datasets/{repo_id}/{filename}", fs=fs)
+
+    with fs.open(file_path) as f:
+        return get_file_column_names(f, file_type)
+
+
+def fetch_seed_dataset_column_names_from_local_file(dataset_path: str | Path) -> list[str]:
+    dataset_path = _validate_dataset_path(dataset_path, allow_glob_pattern=True)
+    dataset_path = _extract_single_file_path_from_glob_pattern_if_present(dataset_path)
+    return get_file_column_names(dataset_path, str(dataset_path).split(".")[-1])
 
 
 def resolve_datastore_settings(datastore_settings: DatastoreSettings | dict | None) -> DatastoreSettings:
@@ -114,25 +141,34 @@ def upload_to_hf_hub(
     return f"{repo_id}/{filename}"
 
 
-def
-
-
-
-
-    file_type = filename.split(".")[-1]
-    if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS:
-        raise InvalidFileFormatError(f"🛑 Unsupported file type: {filename!r}")
-
-    datastore_settings = resolve_datastore_settings(datastore_settings)
-    fs = HfFileSystem(endpoint=datastore_settings.endpoint, token=datastore_settings.token)
-
-    with fs.open(f"datasets/{repo_id}/{filename}") as f:
-        return get_file_column_names(f, file_type)
-
+def _extract_single_file_path_from_glob_pattern_if_present(
+    file_path: str | Path,
+    fs: HfFileSystem | None = None,
+) -> Path:
+    file_path = Path(file_path)
 
-
-
-
+    # no glob pattern
+    if "*" not in str(file_path):
+        return file_path
+
+    # glob pattern with HfFileSystem
+    if fs is not None:
+        file_to_check = None
+        file_extension = file_path.name.split(".")[-1]
+        for file in fs.ls(str(file_path.parent)):
+            filename = file["name"]
+            if filename.endswith(f".{file_extension}"):
+                file_to_check = filename
+        if file_to_check is None:
+            raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
+        logger.debug(f"Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
+        return Path(file_to_check)
+
+    # glob pattern with local file system
+    if not (matching_files := sorted(file_path.parent.glob(file_path.name))):
+        raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
+    logger.debug(f"Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
+    return matching_files[0]
 
 
 def _validate_dataset_path(dataset_path: Union[str, Path], allow_glob_pattern: bool = False) -> Path:
```
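For context, the reworked `get_file_column_names` only ever touches file metadata or the first record, never the full dataset. A rough sketch of the same header-only reads, assuming plain local paths (the real function also accepts open `HfFileSystem` file handles, and the function name below is illustrative):

```python
import pandas as pd
import pyarrow.parquet as pq


def peek_columns(path: str) -> list[str]:
    """Return column names without loading the dataset."""
    if path.endswith(".parquet"):
        return list(pq.read_schema(path).names)  # schema only, no row data
    if path.endswith((".json", ".jsonl")):
        # nrows requires lines=True (JSON Lines); this reads a single record
        return pd.read_json(path, orient="records", lines=True, nrows=1).columns.tolist()
    if path.endswith(".csv"):
        return pd.read_csv(path, nrows=1).columns.tolist()  # header plus one row
    raise ValueError(f"Unsupported file type: {path!r}")
```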
data_designer/config/default_model_settings.py
CHANGED

```diff
@@ -4,6 +4,7 @@
 
 from functools import lru_cache
 import logging
+import os
 from pathlib import Path
 from typing import Any, Literal, Optional
 
@@ -15,7 +16,6 @@ from .utils.constants import (
     PREDEFINED_PROVIDERS,
     PREDEFINED_PROVIDERS_MODEL_MAP,
 )
-from .utils.info import ConfigBuilderInfo, InfoType, InterfaceInfo
 from .utils.io_helpers import load_config_file, save_config_file
 
 logger = logging.getLogger(__name__)
@@ -75,7 +75,15 @@ def get_default_model_configs() -> list[ModelConfig]:
     config_dict = load_config_file(MODEL_CONFIGS_FILE_PATH)
     if "model_configs" in config_dict:
         return [ModelConfig.model_validate(mc) for mc in config_dict["model_configs"]]
-
+    return []
+
+
+def get_default_model_providers_missing_api_keys() -> list[str]:
+    missing_api_keys = []
+    for predefined_provider in PREDEFINED_PROVIDERS:
+        if os.environ.get(predefined_provider["api_key"]) is None:
+            missing_api_keys.append(predefined_provider["api_key"])
+    return missing_api_keys
 
 
 def get_default_providers() -> list[ModelProvider]:
@@ -91,21 +99,17 @@ def get_default_provider_name() -> Optional[str]:
 
 def resolve_seed_default_model_settings() -> None:
     if not MODEL_CONFIGS_FILE_PATH.exists():
-        logger.
+        logger.debug(
             f"🍾 Default model configs were not found, so writing the following to {str(MODEL_CONFIGS_FILE_PATH)!r}"
         )
-        config_builder_info = ConfigBuilderInfo(model_configs=get_builtin_model_configs())
-        config_builder_info.display(info_type=InfoType.MODEL_CONFIGS)
         save_config_file(
             MODEL_CONFIGS_FILE_PATH, {"model_configs": [mc.model_dump() for mc in get_builtin_model_configs()]}
         )
 
     if not MODEL_PROVIDERS_FILE_PATH.exists():
-        logger.
+        logger.debug(
             f"🪄 Default model providers were not found, so writing the following to {str(MODEL_PROVIDERS_FILE_PATH)!r}"
         )
-        interface_info = InterfaceInfo(model_providers=get_builtin_model_providers())
-        interface_info.display(info_type=InfoType.MODEL_PROVIDERS)
         save_config_file(
             MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump() for p in get_builtin_model_providers()]}
         )
```
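The new `get_default_model_providers_missing_api_keys` simply scans the predefined providers for unset environment variables. A sketch under the assumption that each `PREDEFINED_PROVIDERS` entry stores its API-key environment-variable name under `"api_key"`; the sample entries below are hypothetical, though the env-var names appear in the package README:

```python
import os

# Hypothetical sample data mirroring the assumed PREDEFINED_PROVIDERS shape.
PREDEFINED_PROVIDERS = [
    {"name": "nvidia", "api_key": "NVIDIA_API_KEY"},
    {"name": "openai", "api_key": "OPENAI_API_KEY"},
]

# Collect the env-var names that are not set in the current environment.
missing = [p["api_key"] for p in PREDEFINED_PROVIDERS if os.environ.get(p["api_key"]) is None]
print(missing)  # e.g. ["NVIDIA_API_KEY"] when only OPENAI_API_KEY is set
```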
data_designer/config/sampler_params.py
CHANGED

```diff
@@ -66,6 +66,7 @@ class CategorySamplerParams(ConfigBase):
             "Larger values will be sampled with higher probability."
         ),
     )
+    sampler_type: Literal[SamplerType.CATEGORY] = SamplerType.CATEGORY
 
     @model_validator(mode="after")
     def _normalize_weights_if_needed(self) -> Self:
@@ -106,6 +107,7 @@ class DatetimeSamplerParams(ConfigBase):
         default="D",
         description="Sampling units, e.g. the smallest possible time interval between samples.",
     )
+    sampler_type: Literal[SamplerType.DATETIME] = SamplerType.DATETIME
 
     @field_validator("start", "end")
     @classmethod
@@ -136,6 +138,7 @@ class SubcategorySamplerParams(ConfigBase):
         ...,
         description="Mapping from each value of parent category to a list of subcategory values.",
     )
+    sampler_type: Literal[SamplerType.SUBCATEGORY] = SamplerType.SUBCATEGORY
 
 
 class TimeDeltaSamplerParams(ConfigBase):
@@ -187,6 +190,7 @@ class TimeDeltaSamplerParams(ConfigBase):
         default="D",
         description="Sampling units, e.g. the smallest possible time interval between samples.",
     )
+    sampler_type: Literal[SamplerType.TIMEDELTA] = SamplerType.TIMEDELTA
 
     @model_validator(mode="after")
     def _validate_min_less_than_max(self) -> Self:
@@ -219,6 +223,7 @@ class UUIDSamplerParams(ConfigBase):
         default=False,
         description="If true, all letters in the UUID will be capitalized.",
     )
+    sampler_type: Literal[SamplerType.UUID] = SamplerType.UUID
 
     @property
     def last_index(self) -> int:
@@ -257,6 +262,7 @@ class ScipySamplerParams(ConfigBase):
     decimal_places: Optional[int] = Field(
         default=None, description="Number of decimal places to round the sampled values to."
     )
+    sampler_type: Literal[SamplerType.SCIPY] = SamplerType.SCIPY
 
 
 class BinomialSamplerParams(ConfigBase):
@@ -273,6 +279,7 @@ class BinomialSamplerParams(ConfigBase):
 
     n: int = Field(..., description="Number of trials.")
     p: float = Field(..., description="Probability of success on each trial.", ge=0.0, le=1.0)
+    sampler_type: Literal[SamplerType.BINOMIAL] = SamplerType.BINOMIAL
 
 
 class BernoulliSamplerParams(ConfigBase):
@@ -288,6 +295,7 @@ class BernoulliSamplerParams(ConfigBase):
     """
 
     p: float = Field(..., description="Probability of success.", ge=0.0, le=1.0)
+    sampler_type: Literal[SamplerType.BERNOULLI] = SamplerType.BERNOULLI
 
 
 class BernoulliMixtureSamplerParams(ConfigBase):
@@ -327,6 +335,7 @@ class BernoulliMixtureSamplerParams(ConfigBase):
         ...,
         description="Parameters of the scipy.stats distribution given in `dist_name`.",
     )
+    sampler_type: Literal[SamplerType.BERNOULLI_MIXTURE] = SamplerType.BERNOULLI_MIXTURE
 
 
 class GaussianSamplerParams(ConfigBase):
@@ -350,6 +359,7 @@ class GaussianSamplerParams(ConfigBase):
     decimal_places: Optional[int] = Field(
         default=None, description="Number of decimal places to round the sampled values to."
     )
+    sampler_type: Literal[SamplerType.GAUSSIAN] = SamplerType.GAUSSIAN
 
 
 class PoissonSamplerParams(ConfigBase):
@@ -369,6 +379,7 @@ class PoissonSamplerParams(ConfigBase):
     """
 
     mean: float = Field(..., description="Mean number of events in a fixed interval.")
+    sampler_type: Literal[SamplerType.POISSON] = SamplerType.POISSON
 
 
 class UniformSamplerParams(ConfigBase):
@@ -390,6 +401,7 @@ class UniformSamplerParams(ConfigBase):
     decimal_places: Optional[int] = Field(
         default=None, description="Number of decimal places to round the sampled values to."
     )
+    sampler_type: Literal[SamplerType.UNIFORM] = SamplerType.UNIFORM
 
 
 #########################################
@@ -470,11 +482,12 @@ class PersonSamplerParams(ConfigBase):
         default=False,
         description="If True, then append synthetic persona columns to each generated person.",
     )
+    sampler_type: Literal[SamplerType.PERSON] = SamplerType.PERSON
 
     @property
     def generator_kwargs(self) -> list[str]:
         """Keyword arguments to pass to the person generator."""
-        return [f for f in list(PersonSamplerParams.model_fields) if f
+        return [f for f in list(PersonSamplerParams.model_fields) if f not in ("locale", "sampler_type")]
 
     @property
     def people_gen_key(self) -> str:
@@ -533,11 +546,12 @@ class PersonFromFakerSamplerParams(ConfigBase):
         min_length=2,
         max_length=2,
     )
+    sampler_type: Literal[SamplerType.PERSON_FROM_FAKER] = SamplerType.PERSON_FROM_FAKER
 
     @property
     def generator_kwargs(self) -> list[str]:
         """Keyword arguments to pass to the person generator."""
-        return [f for f in list(PersonFromFakerSamplerParams.model_fields) if f
+        return [f for f in list(PersonFromFakerSamplerParams.model_fields) if f not in ("locale", "sampler_type")]
 
     @property
     def people_gen_key(self) -> str:
```
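Each params class now carries a `Literal`-typed `sampler_type` field with a default, which acts as the discriminator tag for the union used in `SamplerColumnConfig`. Because the value is fixed per class, callers never have to set it. A small sketch with an illustrative stand-in class (the real classes tag with `SamplerType` enum members rather than plain strings):

```python
from typing import Literal

from pydantic import BaseModel


class GaussianSamplerParams(BaseModel):  # illustrative stand-in
    mean: float = 0.0
    stddev: float = 1.0
    # Fixed per class: serves as the discriminator tag and needs no caller input.
    sampler_type: Literal["gaussian"] = "gaussian"


p = GaussianSamplerParams(mean=2.5)
print(p.sampler_type)  # "gaussian" — set automatically, never passed by the caller
```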
data_designer/engine/resources/seed_dataset_data_store.py
CHANGED

```diff
@@ -42,11 +42,29 @@ class HfHubSeedDatasetDataStore(SeedDatasetDataStore):
 
     def __init__(self, endpoint: str, token: str | None):
         self.hfapi = HfApi(endpoint=endpoint, token=token)
-        self.
+        self.endpoint = endpoint
+        self.token = token
 
     def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
+        """Create a DuckDB connection with a fresh HfFileSystem registered.
+
+        Creates a new HfFileSystem instance for each connection to ensure file metadata
+        is fetched fresh from the datastore, avoiding cache-related issues when reading
+        recently updated parquet files.
+
+        Returns:
+            A DuckDB connection with the HfFileSystem registered for hf:// URI support.
+        """
+        # Use skip_instance_cache to avoid fsspec-level caching
+        hffs = HfFileSystem(endpoint=self.endpoint, token=self.token, skip_instance_cache=True)
+
+        # Clear all internal caches to avoid stale metadata issues
+        # HfFileSystem caches file metadata (size, etc.) which can become stale when files are re-uploaded
+        if hasattr(hffs, "dircache"):
+            hffs.dircache.clear()
+
         conn = duckdb.connect()
-        conn.register_filesystem(
+        conn.register_filesystem(hffs)
         return conn
 
     def get_dataset_uri(self, file_id: str) -> str:
```
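The pattern above, in isolation: build a fresh `HfFileSystem` per connection with fsspec instance caching disabled, drop its directory cache, and register it so DuckDB can read `hf://` URIs. A condensed sketch; the real method also guards the `dircache` access with `hasattr`, and the function name here is illustrative:

```python
import duckdb
from huggingface_hub import HfFileSystem


def fresh_connection(endpoint: str, token: str | None) -> duckdb.DuckDBPyConnection:
    # New filesystem instance per connection; skip_instance_cache bypasses
    # fsspec's instance cache so file metadata is fetched fresh.
    fs = HfFileSystem(endpoint=endpoint, token=token, skip_instance_cache=True)
    fs.dircache.clear()  # drop any cached directory/file metadata
    conn = duckdb.connect()
    conn.register_filesystem(fs)  # enables SQL reads from hf:// URIs
    return conn
```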
data_designer/interface/data_designer.py
CHANGED

```diff
@@ -10,6 +10,7 @@ from data_designer.config.analysis.dataset_profiler import DatasetProfilerResult
 from data_designer.config.config_builder import DataDesignerConfigBuilder
 from data_designer.config.default_model_settings import (
     get_default_model_configs,
+    get_default_model_providers_missing_api_keys,
     get_default_provider_name,
     get_default_providers,
     resolve_seed_default_model_settings,
@@ -26,8 +27,9 @@ from data_designer.config.utils.constants import (
     MANAGED_ASSETS_PATH,
     MODEL_CONFIGS_FILE_PATH,
     MODEL_PROVIDERS_FILE_PATH,
+    PREDEFINED_PROVIDERS,
 )
-from data_designer.config.utils.info import InterfaceInfo
+from data_designer.config.utils.info import InfoType, InterfaceInfo
 from data_designer.config.utils.io_helpers import write_seed_dataset
 from data_designer.config.utils.misc import can_run_data_designer_locally
 from data_designer.engine.analysis.dataset_profiler import (
@@ -103,7 +105,7 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
         self._artifact_path = Path(artifact_path) if artifact_path is not None else Path.cwd() / "artifacts"
         self._buffer_size = DEFAULT_BUFFER_SIZE
         self._managed_assets_path = Path(managed_assets_path or MANAGED_ASSETS_PATH)
-        self._model_providers =
+        self._model_providers = self._resolve_model_providers(model_providers)
         self._model_provider_registry = resolve_model_provider_registry(
             self._model_providers, get_default_provider_name()
         )
@@ -151,7 +153,7 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
         Returns:
             InterfaceInfo object with information about the Data Designer interface.
         """
-        return
+        return self._get_interface_info(self._model_providers)
 
     def create(
         self,
@@ -307,6 +309,22 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
             raise InvalidBufferValueError("Buffer size must be greater than 0.")
         self._buffer_size = buffer_size
 
+    def _resolve_model_providers(self, model_providers: list[ModelProvider] | None) -> list[ModelProvider]:
+        if model_providers is None:
+            if can_run_data_designer_locally():
+                model_providers = get_default_providers()
+                missing_api_keys = get_default_model_providers_missing_api_keys()
+                if len(missing_api_keys) == len(PREDEFINED_PROVIDERS):
+                    logger.warning(
+                        "🚨 You are trying to use a default model provider but your API keys are missing."
+                        "\n\t\t\tSet the API key for the default providers you intend to use and re-initialize the Data Designer object."
+                        "\n\t\t\tAlternatively, you can provide your own model providers during Data Designer object initialization."
+                        "\n\t\t\tSee https://nvidia-nemo.github.io/DataDesigner/models/model-providers/ for more information."
+                    )
+                self._get_interface_info(model_providers).display(InfoType.MODEL_PROVIDERS)
+                return model_providers
+        return model_providers or []
+
     def _create_dataset_builder(
         self, config_builder: DataDesignerConfigBuilder, resource_provider: ResourceProvider
     ) -> ColumnWiseDatasetBuilder:
@@ -349,3 +367,6 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
             )
         ),
     )
+
+    def _get_interface_info(self, model_providers: list[ModelProvider]) -> InterfaceInfo:
+        return InterfaceInfo(model_providers=model_providers)
```
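From the caller's side, the effect of `_resolve_model_providers` is that constructing `DataDesigner` without providers now falls back to the defaults and warns when none of the predefined API keys are set. A usage sketch, assuming only the import path and parameter name visible in this diff; other constructor arguments may apply:

```python
from data_designer.interface.data_designer import DataDesigner

# model_providers=None triggers the new fallback: if local execution is
# possible, the default providers are used, and a warning is logged when
# every predefined provider's API key is missing from the environment.
dd = DataDesigner(model_providers=None)
```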
{data_designer-0.1.0.dist-info → data_designer-0.1.2.dist-info}/METADATA
CHANGED

````diff
@@ -1,18 +1,18 @@
 Metadata-Version: 2.4
 Name: data-designer
-Version: 0.1.0
+Version: 0.1.2
 Summary: General framework for synthetic data generation
+License-Expression: Apache-2.0
 License-File: LICENSE
 Classifier: Development Status :: 4 - Beta
 Classifier: Intended Audience :: Developers
 Classifier: Intended Audience :: Science/Research
-Classifier: License ::
+Classifier: License :: OSI Approved :: Apache Software License
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Classifier: Topic :: Scientific/Engineering :: Human Machine Interfaces
 Classifier: Topic :: Software Development
 Requires-Python: >=3.10
 Requires-Dist: anyascii<1.0,>=0.3.3
@@ -51,7 +51,7 @@ Description-Content-Type: text/markdown
 
 [](https://github.com/NVIDIA-NeMo/DataDesigner/actions/workflows/ci.yml)
 [](https://opensource.org/licenses/Apache-2.0)
-[](https://www.python.org/downloads/) [](https://docs.nvidia.com/nemo/microservices/latest/index.html) [](https://nvidia-nemo.github.io/DataDesigner/)
 
 **Generate high-quality synthetic datasets from scratch or using your own seed data.**
 
@@ -97,8 +97,7 @@ export NVIDIA_API_KEY="your-api-key-here"
 export OPENAI_API_KEY="your-openai-api-key-here"
 ```
 
-### 3.
-
+### 3. Start generating data!
 ```python
 from data_designer.essentials import (
     CategorySamplerParams,
@@ -139,18 +138,18 @@ preview = data_designer.preview(config_builder=config_builder)
 preview.display_sample_record()
 ```
 
-**That's it!** You've created a dataset.
-
 ---
 
 ## What's next?
 
 ### 📚 Learn more
 
-- **[Quick Start Guide](https://nvidia-nemo.github.io/DataDesigner)** – Detailed walkthrough with more examples
-- **[Tutorial Notebooks](https://nvidia-nemo.github.io/DataDesigner/notebooks/
+- **[Quick Start Guide](https://nvidia-nemo.github.io/DataDesigner/quick-start/)** – Detailed walkthrough with more examples
+- **[Tutorial Notebooks](https://nvidia-nemo.github.io/DataDesigner/notebooks/)** – Step-by-step interactive tutorials
 - **[Column Types](https://nvidia-nemo.github.io/DataDesigner/concepts/columns/)** – Explore samplers, LLM columns, validators, and more
+- **[Validators](https://nvidia-nemo.github.io/DataDesigner/concepts/validators/)** – Learn how to validate generated data with Python, SQL, and remote validators
 - **[Model Configuration](https://nvidia-nemo.github.io/DataDesigner/models/model-configs/)** – Configure custom models and providers
+- **[Person Sampling](https://nvidia-nemo.github.io/DataDesigner/concepts/person_sampling/)** – Learn how to sample realistic person data with demographic attributes
 
 ### 🔧 Configure models via CLI
 
@@ -162,12 +161,27 @@ data-designer config list # View current settings
 
 ### 🤝 Get involved
 
-- **[Contributing Guide](https://nvidia-nemo.github.io/DataDesigner/CONTRIBUTING
-- **[GitHub Issues](https://github.com/NVIDIA-NeMo/DataDesigner/issues)** – Report bugs or request
-- **[GitHub Discussions](https://github.com/NVIDIA-NeMo/DataDesigner/discussions)** – Ask questions and share ideas
+- **[Contributing Guide](https://nvidia-nemo.github.io/DataDesigner/CONTRIBUTING)** – Help improve Data Designer
+- **[GitHub Issues](https://github.com/NVIDIA-NeMo/DataDesigner/issues)** – Report bugs or make a feature request
 
 ---
 
 ## License
 
 Apache License 2.0 – see [LICENSE](LICENSE) for details.
+
+---
+
+## Citation
+
+If you use NeMo Data Designer in your research, please cite it using the following BibTeX entry:
+
+```bibtex
+@misc{nemo-data-designer,
+  author = {The NeMo Data Designer Team},
+  title = {NeMo Data Designer: A framework for generating synthetic data from scratch or based on your own seed data},
+  howpublished = {\url{https://github.com/NVIDIA-NeMo/DataDesigner}},
+  year = {2025},
+  note = {GitHub Repository},
+}
+```
````
{data_designer-0.1.0.dist-info → data_designer-0.1.2.dist-info}/RECORD
CHANGED

```diff
@@ -1,5 +1,5 @@
 data_designer/__init__.py,sha256=iCeqRnb640RrL2QpA630GY5Ng7JiDt83Vq0DwLnNugU,461
-data_designer/_version.py,sha256=
+data_designer/_version.py,sha256=Ok5oAXdWgR9aghaFXTafTeDW6sYO3uVe6d2Nket57R4,704
 data_designer/errors.py,sha256=Z4eN9XwzZvGRdBluSNoSqQYkPPzNQIDf0ET_OqWRZh8,179
 data_designer/logging.py,sha256=O6LlQRj4IdkvEEYiMkKfMb_ZDgN1YpkGQUCqcp7nY6w,5354
 data_designer/plugin_manager.py,sha256=jWoo80x0oCiOIJMA43t-vK-_hVv9_xt4WhBcurYoDqw,3098
@@ -31,20 +31,20 @@ data_designer/cli/services/model_service.py,sha256=Fn3c0qMZqFAEqzBr0haLjp-nLKAkk
 data_designer/cli/services/provider_service.py,sha256=pdD2_C4yK0YBabcuan95H86UreZJ5zWFGI3Ue99mXXo,3916
 data_designer/config/__init__.py,sha256=9eG4WHKyrJcNoK4GEz6BCw_E0Ewo9elQoDN4TLMbAog,137
 data_designer/config/base.py,sha256=xCbvwxXKRityWqeGP4zTXVuPHAOoUdpuQr8_t8vY8f8,2423
-data_designer/config/column_configs.py,sha256=
+data_designer/config/column_configs.py,sha256=ixpanQApbn4LUyW7E4IJefXQG6c0eYbGxF-GGwV1xCg,18000
 data_designer/config/column_types.py,sha256=V0Ijwb-asYOX-GQyG9W-X_A-FIbFSajKuus58sG8CSM,6774
 data_designer/config/config_builder.py,sha256=NlAe6cwN6IAE90A8uPLsOdABmmYyUt6UnGYZwgmf_xE,27288
 data_designer/config/data_designer_config.py,sha256=cvIXMVQzYn9vC4GINPz972pDBmt-HrV5dvw1568LVmE,1719
 data_designer/config/dataset_builders.py,sha256=1pNFy_pkQ5lJ6AVZ43AeTuSbz6yC_l7Ndcyp5yaT8hQ,327
-data_designer/config/datastore.py,sha256=
-data_designer/config/default_model_settings.py,sha256=
+data_designer/config/datastore.py,sha256=Ra6MsPCK6Q1Y8JbTQGRrKtyceig1s41ishyKSZoxgno,7572
+data_designer/config/default_model_settings.py,sha256=aMud_RrRStHnDSbwLxU3BnmIu08YtB1-EG6UUY9NedI,4517
 data_designer/config/errors.py,sha256=XneHH6tKHG2sZ71HzmPr7k3UBZ_psnSANknT30n-aa8,449
 data_designer/config/interface.py,sha256=2_tHvxtKAv0C5L7K4ztm-Xa1A-u9Njlwo2drdPa2qmk,1499
 data_designer/config/models.py,sha256=5Cy55BnKYyr-I1UHLUTqZxe6Ca9uVQWpUiwt9X0ZlrU,7521
 data_designer/config/preview_results.py,sha256=H6ETFI6L1TW8MEC9KYsJ1tXGIC5cloCggBCCZd6jiEE,1087
 data_designer/config/processors.py,sha256=qOF_plBoh6UEFNwUpyDgkqIuSDUaSM2S7k-kSAEB5p8,1328
 data_designer/config/sampler_constraints.py,sha256=4JxP-nge5KstqtctJnVg5RLM1w9mA7qFi_BjgTJl9CE,1167
-data_designer/config/sampler_params.py,sha256=
+data_designer/config/sampler_params.py,sha256=NCm2uWEzFHjz8ZzSmiKcVp5jI5okp53tq9l-bWBm4FQ,26821
 data_designer/config/seed.py,sha256=g-iUToYSIFuTv3sbwSG_dF-9RwC8r8AvCD-vS8c_jDg,5487
 data_designer/config/validator_params.py,sha256=sNxFIF2bk_N4jJD-aMH1N5MQynDip08AoMI1ajxtRdc,3909
 data_designer/config/analysis/column_profilers.py,sha256=Qss9gr7oHNcjijW_MMIX9JkFX-V9v5vPwYWCnxLjMDY,2749
@@ -133,7 +133,7 @@ data_designer/engine/resources/managed_dataset_generator.py,sha256=KXrWdgod-NFaC
 data_designer/engine/resources/managed_dataset_repository.py,sha256=lqVxuoCxc07QTrhnAR1mgDiHFkzjjkx2IwcrxrdbloY,7547
 data_designer/engine/resources/managed_storage.py,sha256=jRnGeCTGlu6FxC6tOCssPiSpbHEf0mbqFfm3mM0utdA,2079
 data_designer/engine/resources/resource_provider.py,sha256=CbB2D538ECGkvyHF1V63_TDn-wStCoklV7bF0y4mabY,1859
-data_designer/engine/resources/seed_dataset_data_store.py,sha256=
+data_designer/engine/resources/seed_dataset_data_store.py,sha256=dM2HgfyUgbF7MidN8dn5S-LAR0GVPJfjqXpDPTP2XoA,3035
 data_designer/engine/sampling_gen/column.py,sha256=gDIPth7vK2797rGtLhf_kVGMAC-khefKHodeeDoqV-I,3946
 data_designer/engine/sampling_gen/constraints.py,sha256=RyhRF9KeUOwEiHr_TN3QwLWOVLTpuCFpCI_3Qr-9Whs,3028
 data_designer/engine/sampling_gen/errors.py,sha256=UBZBtosD07EisCdeo8r-Uq4h0QL3tYS1qwtEmca8_jM,828
@@ -163,15 +163,15 @@ data_designer/engine/validators/remote.py,sha256=jtDIvWzfHh17m2ac_Fp93p49Th8RlkB
 data_designer/engine/validators/sql.py,sha256=bxbyxPxDT9yuwjhABVEY40iR1pzWRFi65WU4tPgG2bE,2250
 data_designer/essentials/__init__.py,sha256=zrDZ7hahOmOhCPdfoj0z9ALN10lXIesfwd2qXRqTcdY,4125
 data_designer/interface/__init__.py,sha256=9eG4WHKyrJcNoK4GEz6BCw_E0Ewo9elQoDN4TLMbAog,137
-data_designer/interface/data_designer.py,sha256=
+data_designer/interface/data_designer.py,sha256=EzOT_kkWXm9-1Zgbj4RvBfV6_r5ABR7mOuNwbgvKKLQ,16273
 data_designer/interface/errors.py,sha256=jagKT3tPUnYq4e3e6AkTnBkcayHyEfxjPMBzx-GEKe4,565
 data_designer/interface/results.py,sha256=qFxa8SuCXeADiRpaCMBwJcExkJBCfUPeGCdcJSTjoTc,2111
 data_designer/plugins/__init__.py,sha256=c_V7q4QhfVoNf_uc9UwmXCsWqwtyWogI7YoN_0PzzE4,234
 data_designer/plugins/errors.py,sha256=yPIHpSddEr-o9ZcNVibb2hI-73O15Kg_Od8SlmQlnRs,297
 data_designer/plugins/plugin.py,sha256=7ErdUyrTdOb5PCBE3msdhTOrvQpldjOQw90-Bu4Bosc,2522
 data_designer/plugins/registry.py,sha256=iPDTh4duV1cKt7H1fXkj1bKLG6SyUKmzQ9xh-vjEoaM,3018
-data_designer-0.1.0.dist-info/METADATA,sha256=
-data_designer-0.1.0.dist-info/WHEEL,sha256=
-data_designer-0.1.0.dist-info/entry_points.txt,sha256=
-data_designer-0.1.0.dist-info/licenses/LICENSE,sha256=
-data_designer-0.1.0.dist-info/RECORD,,
+data_designer-0.1.2.dist-info/METADATA,sha256=PjPyL9UQ0Ys4XPqRuruAjuUJ6XPMDf1n1bz17wwoct4,6644
+data_designer-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+data_designer-0.1.2.dist-info/entry_points.txt,sha256=NWWWidyDxN6CYX6y664PhBYMhbaYTQTyprqfYAgkyCg,57
+data_designer-0.1.2.dist-info/licenses/LICENSE,sha256=cSWJDwVqHyQgly8Zmt3pqXJ2eQbZVYwN9qd0NMssxXY,11336
+data_designer-0.1.2.dist-info/RECORD,,
```
{data_designer-0.1.0.dist-info → data_designer-0.1.2.dist-info}/WHEEL
File without changes

{data_designer-0.1.0.dist-info → data_designer-0.1.2.dist-info}/entry_points.txt
File without changes

{data_designer-0.1.0.dist-info → data_designer-0.1.2.dist-info}/licenses/LICENSE
File without changes