data-designer 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/config/column_configs.py +29 -4
- data_designer/config/datastore.py +70 -34
- data_designer/config/default_model_settings.py +1 -1
- data_designer/config/sampler_params.py +16 -5
- data_designer/engine/dataset_builders/artifact_storage.py +15 -1
- data_designer/engine/dataset_builders/column_wise_builder.py +2 -2
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +31 -9
- data_designer/interface/data_designer.py +7 -3
- {data_designer-0.1.1.dist-info → data_designer-0.1.3.dist-info}/METADATA +3 -6
- {data_designer-0.1.1.dist-info → data_designer-0.1.3.dist-info}/RECORD +14 -14
- {data_designer-0.1.1.dist-info → data_designer-0.1.3.dist-info}/WHEEL +1 -1
- {data_designer-0.1.1.dist-info → data_designer-0.1.3.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.1.dist-info → data_designer-0.1.3.dist-info}/licenses/LICENSE +0 -0
data_designer/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.1.1'
-__version_tuple__ = version_tuple = (0, 1, 1)
+__version__ = version = '0.1.3'
+__version_tuple__ = version_tuple = (0, 1, 3)
 
 __commit_id__ = commit_id = None
data_designer/config/column_configs.py
CHANGED
@@ -2,9 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from abc import ABC
-from typing import Literal, Optional, Type, Union
+from typing import Annotated, Literal, Optional, Type, Union
 
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Discriminator, Field, model_validator
 from typing_extensions import Self
 
 from .base import ConfigBase
@@ -89,11 +89,36 @@ class SamplerColumnConfig(SingleColumnConfig):
     """
 
     sampler_type: SamplerType
-    params: SamplerParamsT
-    conditional_params: dict[str, SamplerParamsT] = {}
+    params: Annotated[SamplerParamsT, Discriminator("sampler_type")]
+    conditional_params: dict[str, Annotated[SamplerParamsT, Discriminator("sampler_type")]] = {}
     convert_to: Optional[str] = None
     column_type: Literal["sampler"] = "sampler"
 
+    @model_validator(mode="before")
+    @classmethod
+    def inject_sampler_type_into_params(cls, data: dict) -> dict:
+        """Inject sampler_type into params dict to enable discriminated union resolution.
+
+        This allows users to pass params as a simple dict without the sampler_type field,
+        which will be automatically added based on the outer sampler_type field.
+        """
+        if isinstance(data, dict):
+            sampler_type = data.get("sampler_type")
+            params = data.get("params")
+
+            # If params is a dict and doesn't have sampler_type, inject it
+            if sampler_type and isinstance(params, dict) and "sampler_type" not in params:
+                data["params"] = {"sampler_type": sampler_type, **params}
+
+            # Handle conditional_params similarly
+            conditional_params = data.get("conditional_params")
+            if conditional_params and isinstance(conditional_params, dict):
+                for condition, cond_params in conditional_params.items():
+                    if isinstance(cond_params, dict) and "sampler_type" not in cond_params:
+                        data["conditional_params"][condition] = {"sampler_type": sampler_type, **cond_params}
+
+        return data
+
 
 class LLMTextColumnConfig(SingleColumnConfig):
     """Configuration for text generation columns using Large Language Models.
data_designer/config/datastore.py
CHANGED
@@ -31,34 +31,37 @@ class DatastoreSettings(BaseModel):
     token: Optional[str] = Field(default=None, description="If needed, token to use for authentication.")
 
 
-def get_file_column_names(file_path: Union[str, Path], file_type: str) -> list[str]:
-    """
-
-
-
-
-        raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
-    logger.debug(f"0️⃣ Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
-    file_path = matching_files[0]
+def get_file_column_names(file_reference: Union[str, Path, HfFileSystem], file_type: str) -> list[str]:
+    """Get column names from a dataset file.
+
+    Args:
+        file_reference: Path to the dataset file, or an HfFileSystem object.
+        file_type: Type of the dataset file. Must be one of: 'parquet', 'json', 'jsonl', 'csv'.
 
+    Raises:
+        InvalidFilePathError: If the file type is not supported.
+
+    Returns:
+        List of column names.
+    """
     if file_type == "parquet":
         try:
-            schema = pq.read_schema(file_path)
+            schema = pq.read_schema(file_reference)
             if hasattr(schema, "names"):
                 return schema.names
             else:
                 return [field.name for field in schema]
         except Exception as e:
-            logger.warning(f"Failed to process parquet file {file_path}: {e}")
+            logger.warning(f"Failed to process parquet file {file_reference}: {e}")
             return []
     elif file_type in ["json", "jsonl"]:
-        return pd.read_json(file_path, orient="records", lines=True, nrows=1).columns.tolist()
+        return pd.read_json(file_reference, orient="records", lines=True, nrows=1).columns.tolist()
     elif file_type == "csv":
        try:
-            df = pd.read_csv(file_path, nrows=1)
+            df = pd.read_csv(file_reference, nrows=1)
             return df.columns.tolist()
         except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:
-            logger.warning(f"Failed to process CSV file {file_path}: {e}")
+            logger.warning(f"Failed to process CSV file {file_reference}: {e}")
             return []
     else:
         raise InvalidFilePathError(f"🛑 Unsupported file type: {file_type!r}")
@@ -66,12 +69,36 @@ def get_file_column_names(file_path: Union[str, Path], file_type: str) -> list[str]:
 
 def fetch_seed_dataset_column_names(seed_dataset_reference: SeedDatasetReference) -> list[str]:
     if hasattr(seed_dataset_reference, "datastore_settings"):
-        return
+        return fetch_seed_dataset_column_names_from_datastore(
             seed_dataset_reference.repo_id,
             seed_dataset_reference.filename,
             seed_dataset_reference.datastore_settings,
         )
-    return
+    return fetch_seed_dataset_column_names_from_local_file(seed_dataset_reference.dataset)
+
+
+def fetch_seed_dataset_column_names_from_datastore(
+    repo_id: str,
+    filename: str,
+    datastore_settings: Optional[Union[DatastoreSettings, dict]] = None,
+) -> list[str]:
+    file_type = filename.split(".")[-1]
+    if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS:
+        raise InvalidFileFormatError(f"🛑 Unsupported file type: {filename!r}")
+
+    datastore_settings = resolve_datastore_settings(datastore_settings)
+    fs = HfFileSystem(endpoint=datastore_settings.endpoint, token=datastore_settings.token, skip_instance_cache=True)
+
+    file_path = _extract_single_file_path_from_glob_pattern_if_present(f"datasets/{repo_id}/{filename}", fs=fs)
+
+    with fs.open(file_path) as f:
+        return get_file_column_names(f, file_type)
+
+
+def fetch_seed_dataset_column_names_from_local_file(dataset_path: str | Path) -> list[str]:
+    dataset_path = _validate_dataset_path(dataset_path, allow_glob_pattern=True)
+    dataset_path = _extract_single_file_path_from_glob_pattern_if_present(dataset_path)
+    return get_file_column_names(dataset_path, str(dataset_path).split(".")[-1])
 
 
 def resolve_datastore_settings(datastore_settings: DatastoreSettings | dict | None) -> DatastoreSettings:
@@ -114,25 +141,34 @@ def upload_to_hf_hub(
     return f"{repo_id}/{filename}"
 
 
-def
-
-
-
-
-    file_type = filename.split(".")[-1]
-    if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS:
-        raise InvalidFileFormatError(f"🛑 Unsupported file type: {filename!r}")
-
-    datastore_settings = resolve_datastore_settings(datastore_settings)
-    fs = HfFileSystem(endpoint=datastore_settings.endpoint, token=datastore_settings.token, skip_instance_cache=True)
-
-    with fs.open(f"datasets/{repo_id}/{filename}") as f:
-        return get_file_column_names(f, file_type)
-
-
-
-
+def _extract_single_file_path_from_glob_pattern_if_present(
+    file_path: str | Path,
+    fs: HfFileSystem | None = None,
+) -> Path:
+    file_path = Path(file_path)
 
+    # no glob pattern
+    if "*" not in str(file_path):
+        return file_path
+
+    # glob pattern with HfFileSystem
+    if fs is not None:
+        file_to_check = None
+        file_extension = file_path.name.split(".")[-1]
+        for file in fs.ls(str(file_path.parent)):
+            filename = file["name"]
+            if filename.endswith(f".{file_extension}"):
+                file_to_check = filename
+        if file_to_check is None:
+            raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
+        logger.debug(f"Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
+        return Path(file_to_check)
+
+    # glob pattern with local file system
+    if not (matching_files := sorted(file_path.parent.glob(file_path.name))):
+        raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
+    logger.debug(f"Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
+    return matching_files[0]
 
 
 def _validate_dataset_path(dataset_path: Union[str, Path], allow_glob_pattern: bool = False) -> Path:
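For local seed datasets, the new `_extract_single_file_path_from_glob_pattern_if_present` narrows a glob pattern to a single concrete file, the first match in sorted order, before column names are read. A self-contained sketch of that local-path branch (function name, error type, and paths are illustrative):

```python
from pathlib import Path

def first_glob_match(file_path: str | Path) -> Path:
    """Resolve a pattern like 'seeds/*.parquet' to its first match, in sorted order."""
    file_path = Path(file_path)
    if "*" not in str(file_path):
        return file_path  # already a concrete path, nothing to resolve
    if not (matches := sorted(file_path.parent.glob(file_path.name))):
        raise FileNotFoundError(f"No files found matching pattern: {file_path}")
    return matches[0]

# e.g. first_glob_match("seeds/*.csv") might return Path("seeds/part-000.csv")
```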
data_designer/config/default_model_settings.py
CHANGED
@@ -78,7 +78,7 @@ def get_default_model_configs() -> list[ModelConfig]:
     return []
 
 
-def get_defaul_model_providers_missing_api_keys() -> list[str]:
+def get_default_model_providers_missing_api_keys() -> list[str]:
     missing_api_keys = []
     for predefined_provider in PREDEFINED_PROVIDERS:
         if os.environ.get(predefined_provider["api_key"]) is None:
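Beyond the rename (fixing the `get_defaul_...` typo, matched by the import fix in data_designer.py below), the visible logic scans PREDEFINED_PROVIDERS for unset API-key environment variables. A hedged sketch of that check; the provider entries and the value collected are assumptions, not the package's exact data:

```python
import os

# Illustrative provider registry; per the loop in the diff, each real
# PREDEFINED_PROVIDERS entry at least carries an "api_key" env-var name.
PREDEFINED_PROVIDERS = [
    {"name": "nvidia", "api_key": "NVIDIA_API_KEY"},
    {"name": "openai", "api_key": "OPENAI_API_KEY"},
]

def get_default_model_providers_missing_api_keys() -> list[str]:
    missing_api_keys = []
    for predefined_provider in PREDEFINED_PROVIDERS:
        if os.environ.get(predefined_provider["api_key"]) is None:
            missing_api_keys.append(predefined_provider["api_key"])
    return missing_api_keys
```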
data_designer/config/sampler_params.py
CHANGED
@@ -66,6 +66,7 @@ class CategorySamplerParams(ConfigBase):
             "Larger values will be sampled with higher probability."
         ),
     )
+    sampler_type: Literal[SamplerType.CATEGORY] = SamplerType.CATEGORY
 
     @model_validator(mode="after")
     def _normalize_weights_if_needed(self) -> Self:
@@ -106,6 +107,7 @@ class DatetimeSamplerParams(ConfigBase):
         default="D",
         description="Sampling units, e.g. the smallest possible time interval between samples.",
     )
+    sampler_type: Literal[SamplerType.DATETIME] = SamplerType.DATETIME
 
     @field_validator("start", "end")
     @classmethod
@@ -136,6 +138,7 @@ class SubcategorySamplerParams(ConfigBase):
         ...,
         description="Mapping from each value of parent category to a list of subcategory values.",
     )
+    sampler_type: Literal[SamplerType.SUBCATEGORY] = SamplerType.SUBCATEGORY
 
 
 class TimeDeltaSamplerParams(ConfigBase):
@@ -187,6 +190,7 @@ class TimeDeltaSamplerParams(ConfigBase):
         default="D",
         description="Sampling units, e.g. the smallest possible time interval between samples.",
     )
+    sampler_type: Literal[SamplerType.TIMEDELTA] = SamplerType.TIMEDELTA
 
     @model_validator(mode="after")
     def _validate_min_less_than_max(self) -> Self:
@@ -219,6 +223,7 @@ class UUIDSamplerParams(ConfigBase):
         default=False,
         description="If true, all letters in the UUID will be capitalized.",
     )
+    sampler_type: Literal[SamplerType.UUID] = SamplerType.UUID
 
     @property
     def last_index(self) -> int:
@@ -257,6 +262,7 @@ class ScipySamplerParams(ConfigBase):
     decimal_places: Optional[int] = Field(
         default=None, description="Number of decimal places to round the sampled values to."
     )
+    sampler_type: Literal[SamplerType.SCIPY] = SamplerType.SCIPY
 
 
 class BinomialSamplerParams(ConfigBase):
@@ -273,6 +279,7 @@ class BinomialSamplerParams(ConfigBase):
 
     n: int = Field(..., description="Number of trials.")
     p: float = Field(..., description="Probability of success on each trial.", ge=0.0, le=1.0)
+    sampler_type: Literal[SamplerType.BINOMIAL] = SamplerType.BINOMIAL
 
 
 class BernoulliSamplerParams(ConfigBase):
@@ -288,6 +295,7 @@ class BernoulliSamplerParams(ConfigBase):
     """
 
     p: float = Field(..., description="Probability of success.", ge=0.0, le=1.0)
+    sampler_type: Literal[SamplerType.BERNOULLI] = SamplerType.BERNOULLI
 
 
 class BernoulliMixtureSamplerParams(ConfigBase):
@@ -327,6 +335,7 @@ class BernoulliMixtureSamplerParams(ConfigBase):
         ...,
         description="Parameters of the scipy.stats distribution given in `dist_name`.",
     )
+    sampler_type: Literal[SamplerType.BERNOULLI_MIXTURE] = SamplerType.BERNOULLI_MIXTURE
 
 
 class GaussianSamplerParams(ConfigBase):
@@ -350,6 +359,7 @@ class GaussianSamplerParams(ConfigBase):
     decimal_places: Optional[int] = Field(
         default=None, description="Number of decimal places to round the sampled values to."
     )
+    sampler_type: Literal[SamplerType.GAUSSIAN] = SamplerType.GAUSSIAN
 
 
 class PoissonSamplerParams(ConfigBase):
@@ -369,6 +379,7 @@ class PoissonSamplerParams(ConfigBase):
     """
 
     mean: float = Field(..., description="Mean number of events in a fixed interval.")
+    sampler_type: Literal[SamplerType.POISSON] = SamplerType.POISSON
 
 
 class UniformSamplerParams(ConfigBase):
@@ -390,6 +401,7 @@ class UniformSamplerParams(ConfigBase):
     decimal_places: Optional[int] = Field(
         default=None, description="Number of decimal places to round the sampled values to."
     )
+    sampler_type: Literal[SamplerType.UNIFORM] = SamplerType.UNIFORM
 
 
 #########################################
@@ -418,9 +430,6 @@ class PersonSamplerParams(ConfigBase):
         age_range: Two-element list [min_age, max_age] specifying the age range to sample from
             (inclusive). Defaults to a standard age range. Both values must be between minimum and
             maximum allowed ages.
-        state: Only supported for "en_US" locale. Filters to sample people from specified US state(s).
-            Must be provided as two-letter state abbreviations (e.g., "CA", "NY", "TX"). Can be a
-            single state or a list of states.
         with_synthetic_personas: If True, appends additional synthetic persona columns including
             personality traits, interests, and background descriptions. Only supported for certain
             locales with managed datasets.
@@ -470,11 +479,12 @@ class PersonSamplerParams(ConfigBase):
         default=False,
         description="If True, then append synthetic persona columns to each generated person.",
     )
+    sampler_type: Literal[SamplerType.PERSON] = SamplerType.PERSON
 
     @property
     def generator_kwargs(self) -> list[str]:
         """Keyword arguments to pass to the person generator."""
-        return [f for f in list(PersonSamplerParams.model_fields) if f != "locale"]
+        return [f for f in list(PersonSamplerParams.model_fields) if f not in ("locale", "sampler_type")]
 
     @property
     def people_gen_key(self) -> str:
@@ -533,11 +543,12 @@ class PersonFromFakerSamplerParams(ConfigBase):
         min_length=2,
         max_length=2,
     )
+    sampler_type: Literal[SamplerType.PERSON_FROM_FAKER] = SamplerType.PERSON_FROM_FAKER
 
     @property
     def generator_kwargs(self) -> list[str]:
         """Keyword arguments to pass to the person generator."""
-        return [f for f in list(PersonFromFakerSamplerParams.model_fields) if f != "locale"]
+        return [f for f in list(PersonFromFakerSamplerParams.model_fields) if f not in ("locale", "sampler_type")]
 
     @property
     def people_gen_key(self) -> str:
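With every params class now carrying a `sampler_type` Literal tag, a serialized params payload is self-describing and validates straight into the right class. A small runnable sketch with stand-in classes (the real union spans all the samplers above):

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Discriminator, Field, TypeAdapter

class BinomialParams(BaseModel):
    sampler_type: Literal["binomial"] = "binomial"
    n: int = Field(..., description="Number of trials.")
    p: float = Field(..., description="Probability of success on each trial.", ge=0.0, le=1.0)

class PoissonParams(BaseModel):
    sampler_type: Literal["poisson"] = "poisson"
    mean: float = Field(..., description="Mean number of events in a fixed interval.")

# The tag doubles as the discriminator, so a plain dict resolves unambiguously.
adapter = TypeAdapter(Annotated[Union[BinomialParams, PoissonParams], Discriminator("sampler_type")])

params = adapter.validate_python({"sampler_type": "poisson", "mean": 3.5})
assert isinstance(params, PoissonParams)
assert adapter.dump_python(params)["sampler_type"] == "poisson"  # round-trips with the tag
```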
data_designer/engine/dataset_builders/artifact_storage.py
CHANGED
@@ -1,6 +1,8 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+from datetime import datetime
+from functools import cached_property
 import json
 import logging
 from pathlib import Path
@@ -36,9 +38,21 @@ class ArtifactStorage(BaseModel):
     def artifact_path_exists(self) -> bool:
         return self.artifact_path.exists()
 
+    @cached_property
+    def resolved_dataset_name(self) -> str:
+        dataset_path = self.artifact_path / self.dataset_name
+        if dataset_path.exists() and len(list(dataset_path.iterdir())) > 0:
+            new_dataset_name = f"{self.dataset_name}_{datetime.now().strftime('%m-%d-%Y_%H%M%S')}"
+            logger.info(
+                f"📂 Dataset path {str(dataset_path)!r} already exists. Dataset from this session"
+                f"\n\t\t will be saved to {str(self.artifact_path / new_dataset_name)!r} instead."
+            )
+            return new_dataset_name
+        return self.dataset_name
+
     @property
     def base_dataset_path(self) -> Path:
-        return self.artifact_path / self.dataset_name
+        return self.artifact_path / self.resolved_dataset_name
 
     @property
     def dropped_columns_dataset_path(self) -> Path:
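The new `resolved_dataset_name` keeps a session from clobbering an existing dataset directory: if the target exists and is non-empty, a timestamp suffix is appended, and `cached_property` ensures the name is resolved once per ArtifactStorage instance. The core rule as a standalone sketch (function name is illustrative):

```python
from datetime import datetime
from pathlib import Path

def resolve_dataset_dir(artifact_path: Path, dataset_name: str) -> Path:
    """Pick a dataset directory, appending a timestamp if the name is already taken."""
    dataset_path = artifact_path / dataset_name
    if dataset_path.exists() and any(dataset_path.iterdir()):
        # Non-empty directory exists: write this session's data elsewhere.
        stamped = f"{dataset_name}_{datetime.now().strftime('%m-%d-%Y_%H%M%S')}"
        return artifact_path / stamped
    return dataset_path
```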
data_designer/engine/dataset_builders/column_wise_builder.py
CHANGED
@@ -88,8 +88,8 @@ class ColumnWiseDatasetBuilder:
         start_time = time.perf_counter()
 
         self.batch_manager.start(num_records=num_records, buffer_size=buffer_size)
-        for batch_idx in range(
-            logger.info(f"⏳ Processing batch {batch_idx} of {self.batch_manager.num_batches}")
+        for batch_idx in range(self.batch_manager.num_batches):
+            logger.info(f"⏳ Processing batch {batch_idx + 1} of {self.batch_manager.num_batches}")
             self._run_batch(generators)
             df_batch = self._run_processors(
                 stage=BuildStage.POST_BATCH,
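The logging change makes batch progress 1-based, so the final message reads "N of N" rather than "N-1 of N". Trivially:

```python
num_batches = 3
for batch_idx in range(num_batches):
    print(f"Processing batch {batch_idx + 1} of {num_batches}")
# Processing batch 1 of 3 ... Processing batch 3 of 3
```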
data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py
CHANGED
@@ -14,6 +14,7 @@ REQUIRED_FIELDS = {"first_name", "last_name", "age", "locale"}
 
 
 PII_FIELDS = [
+    # Core demographic fields
     "uuid",
     "first_name",
     "middle_name",
@@ -22,25 +23,38 @@ PII_FIELDS = [
     "age",
     "birth_date",
     "marital_status",
-    "street_name",
-    "street_number",
-    "unit",
     "postcode",
-    "region",
     "city",
-    "
+    "region",
     "country",
-    "
-    "zone",
+    "locale",
     "bachelors_field",
-    "education_degree",
     "education_level",
     "occupation",
-    "
+    "national_id",
+    # US-specific fields
+    "street_name",
+    "street_number",
+    "unit",
+    "state",
+    "email_address",
+    "phone_number",
+    # Japan-specific fields
+    "area",
+    "prefecture",
+    "zone",
+    # India-specific fields
+    "district",
+    "religion",
+    "education_degree",
+    "first_language",
+    "second_language",
+    "third_language",
 ]
 
 
 PERSONA_FIELDS = [
+    # Core persona fields
     "persona",
     "career_goals_and_ambitions",
     "arts_persona",
@@ -61,4 +75,12 @@ PERSONA_FIELDS = [
     "extraversion",
     "agreeableness",
     "neuroticism",
+    # Japan-specific persona fields
+    "aspects",
+    "digital_skills",
+    # India-specific persona fields
+    "linguistic_persona",
+    "religious_persona",
+    "linguistic_background",
+    "religious_background",
 ]
data_designer/interface/data_designer.py
CHANGED
@@ -9,8 +9,8 @@ import pandas as pd
 from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
 from data_designer.config.config_builder import DataDesignerConfigBuilder
 from data_designer.config.default_model_settings import (
-    get_defaul_model_providers_missing_api_keys,
     get_default_model_configs,
+    get_default_model_providers_missing_api_keys,
     get_default_provider_name,
     get_default_providers,
     resolve_seed_default_model_settings,
@@ -173,7 +173,11 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
                 configuration (columns, constraints, seed data, etc.).
             num_records: Number of records to generate.
             dataset_name: Name of the dataset. This name will be used as the dataset
-                folder name in the artifact path directory.
+                folder name in the artifact path directory. If a non-empty directory with the
+                same name already exists, the dataset will be saved to a new directory with
+                a datetime stamp. For example, if the dataset name is "awesome_dataset" and a
+                directory with that name already exists, the dataset will be saved to a new
+                directory named "awesome_dataset_2025-01-01_12-00-00".
 
         Returns:
             DatasetCreationResults object with methods for loading the generated dataset,
@@ -313,7 +317,7 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
         if model_providers is None:
             if can_run_data_designer_locally():
                 model_providers = get_default_providers()
-                missing_api_keys = get_defaul_model_providers_missing_api_keys()
+                missing_api_keys = get_default_model_providers_missing_api_keys()
                 if len(missing_api_keys) == len(PREDEFINED_PROVIDERS):
                     logger.warning(
                         "🚨 You are trying to use a default model provider but your API keys are missing."
{data_designer-0.1.1.dist-info → data_designer-0.1.3.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: data-designer
-Version: 0.1.1
+Version: 0.1.3
 Summary: General framework for synthetic data generation
 License-Expression: Apache-2.0
 License-File: LICENSE
@@ -97,8 +97,7 @@ export NVIDIA_API_KEY="your-api-key-here"
 export OPENAI_API_KEY="your-openai-api-key-here"
 ```
 
-### 3.
-
+### 3. Start generating data!
 ```python
 from data_designer.essentials import (
     CategorySamplerParams,
@@ -139,8 +138,6 @@ preview = data_designer.preview(config_builder=config_builder)
 preview.display_sample_record()
 ```
 
-**That's it!** You've created a dataset.
-
 ---
 
 ## What's next?
@@ -148,7 +145,7 @@ preview.display_sample_record()
 ### 📚 Learn more
 
 - **[Quick Start Guide](https://nvidia-nemo.github.io/DataDesigner/quick-start/)** – Detailed walkthrough with more examples
-- **[Tutorial Notebooks](https://nvidia-nemo.github.io/DataDesigner/notebooks/
+- **[Tutorial Notebooks](https://nvidia-nemo.github.io/DataDesigner/notebooks/)** – Step-by-step interactive tutorials
 - **[Column Types](https://nvidia-nemo.github.io/DataDesigner/concepts/columns/)** – Explore samplers, LLM columns, validators, and more
 - **[Validators](https://nvidia-nemo.github.io/DataDesigner/concepts/validators/)** – Learn how to validate generated data with Python, SQL, and remote validators
 - **[Model Configuration](https://nvidia-nemo.github.io/DataDesigner/models/model-configs/)** – Configure custom models and providers
{data_designer-0.1.1.dist-info → data_designer-0.1.3.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 data_designer/__init__.py,sha256=iCeqRnb640RrL2QpA630GY5Ng7JiDt83Vq0DwLnNugU,461
-data_designer/_version.py,sha256=
+data_designer/_version.py,sha256=q5nF98G8SoVeJqaknL0xdyxtv0egsqb0fK06_84Izu8,704
 data_designer/errors.py,sha256=Z4eN9XwzZvGRdBluSNoSqQYkPPzNQIDf0ET_OqWRZh8,179
 data_designer/logging.py,sha256=O6LlQRj4IdkvEEYiMkKfMb_ZDgN1YpkGQUCqcp7nY6w,5354
 data_designer/plugin_manager.py,sha256=jWoo80x0oCiOIJMA43t-vK-_hVv9_xt4WhBcurYoDqw,3098
@@ -31,20 +31,20 @@ data_designer/cli/services/model_service.py,sha256=Fn3c0qMZqFAEqzBr0haLjp-nLKAkk
 data_designer/cli/services/provider_service.py,sha256=pdD2_C4yK0YBabcuan95H86UreZJ5zWFGI3Ue99mXXo,3916
 data_designer/config/__init__.py,sha256=9eG4WHKyrJcNoK4GEz6BCw_E0Ewo9elQoDN4TLMbAog,137
 data_designer/config/base.py,sha256=xCbvwxXKRityWqeGP4zTXVuPHAOoUdpuQr8_t8vY8f8,2423
-data_designer/config/column_configs.py,sha256=
+data_designer/config/column_configs.py,sha256=ixpanQApbn4LUyW7E4IJefXQG6c0eYbGxF-GGwV1xCg,18000
 data_designer/config/column_types.py,sha256=V0Ijwb-asYOX-GQyG9W-X_A-FIbFSajKuus58sG8CSM,6774
 data_designer/config/config_builder.py,sha256=NlAe6cwN6IAE90A8uPLsOdABmmYyUt6UnGYZwgmf_xE,27288
 data_designer/config/data_designer_config.py,sha256=cvIXMVQzYn9vC4GINPz972pDBmt-HrV5dvw1568LVmE,1719
 data_designer/config/dataset_builders.py,sha256=1pNFy_pkQ5lJ6AVZ43AeTuSbz6yC_l7Ndcyp5yaT8hQ,327
-data_designer/config/datastore.py,sha256=
-data_designer/config/default_model_settings.py,sha256=
+data_designer/config/datastore.py,sha256=Ra6MsPCK6Q1Y8JbTQGRrKtyceig1s41ishyKSZoxgno,7572
+data_designer/config/default_model_settings.py,sha256=aMud_RrRStHnDSbwLxU3BnmIu08YtB1-EG6UUY9NedI,4517
 data_designer/config/errors.py,sha256=XneHH6tKHG2sZ71HzmPr7k3UBZ_psnSANknT30n-aa8,449
 data_designer/config/interface.py,sha256=2_tHvxtKAv0C5L7K4ztm-Xa1A-u9Njlwo2drdPa2qmk,1499
 data_designer/config/models.py,sha256=5Cy55BnKYyr-I1UHLUTqZxe6Ca9uVQWpUiwt9X0ZlrU,7521
 data_designer/config/preview_results.py,sha256=H6ETFI6L1TW8MEC9KYsJ1tXGIC5cloCggBCCZd6jiEE,1087
 data_designer/config/processors.py,sha256=qOF_plBoh6UEFNwUpyDgkqIuSDUaSM2S7k-kSAEB5p8,1328
 data_designer/config/sampler_constraints.py,sha256=4JxP-nge5KstqtctJnVg5RLM1w9mA7qFi_BjgTJl9CE,1167
-data_designer/config/sampler_params.py,sha256=
+data_designer/config/sampler_params.py,sha256=W2GGRwzWZ4RlJAjDpyqSoF6bjpYjT7WHIhS3D0GfupE,26574
 data_designer/config/seed.py,sha256=g-iUToYSIFuTv3sbwSG_dF-9RwC8r8AvCD-vS8c_jDg,5487
 data_designer/config/validator_params.py,sha256=sNxFIF2bk_N4jJD-aMH1N5MQynDip08AoMI1ajxtRdc,3909
 data_designer/config/analysis/column_profilers.py,sha256=Qss9gr7oHNcjijW_MMIX9JkFX-V9v5vPwYWCnxLjMDY,2749
@@ -87,8 +87,8 @@ data_designer/engine/column_generators/generators/validation.py,sha256=MbDFXzief
 data_designer/engine/column_generators/utils/errors.py,sha256=ugNwaqnPdrPZI7YnKLbYwFjYUSm0WAzgaVu_u6i5Rc8,365
 data_designer/engine/column_generators/utils/judge_score_factory.py,sha256=JRoaZgRGK24dH0zx7MNGSccK196tQK_l0sbwNkurg7c,2132
 data_designer/engine/column_generators/utils/prompt_renderer.py,sha256=d4tbyPsgmFDikW3nxL5is9RNaajMkoPDCrfkQkxw7rc,4760
-data_designer/engine/dataset_builders/artifact_storage.py,sha256=
-data_designer/engine/dataset_builders/column_wise_builder.py,sha256=
+data_designer/engine/dataset_builders/artifact_storage.py,sha256=0hpjJ4s3kQ3h-cEpgtIcDpx3UIEMH1FNX5Sp_8yRU9s,7995
+data_designer/engine/dataset_builders/column_wise_builder.py,sha256=bXaFhFD0GsY-9b_GLXY345N0BH5z2YjiWrs_yFDqYgA,13074
 data_designer/engine/dataset_builders/errors.py,sha256=1kChleChG4rASWIiL4Bel6Ox6aFZjQUrh5ogPt1CDWo,359
 data_designer/engine/dataset_builders/multi_column_configs.py,sha256=t28fhI-WRIBohFnAJ80l5EAETEDB5rJ5RSWInMiRfyE,1619
 data_designer/engine/dataset_builders/utils/__init__.py,sha256=9eG4WHKyrJcNoK4GEz6BCw_E0Ewo9elQoDN4TLMbAog,137
@@ -148,7 +148,7 @@ data_designer/engine/sampling_gen/data_sources/base.py,sha256=BRU9pzDvgB5B1Mgtj8
 data_designer/engine/sampling_gen/data_sources/errors.py,sha256=5pq42e5yvUqaH-g09jWvJolYCO2I2Rdrqo1O0gwet8Y,326
 data_designer/engine/sampling_gen/data_sources/sources.py,sha256=63YaRau37NIc2TDn8JvTOsd0zfnY4_aaF9UOU5ryKSo,13387
 data_designer/engine/sampling_gen/entities/__init__.py,sha256=9eG4WHKyrJcNoK4GEz6BCw_E0Ewo9elQoDN4TLMbAog,137
-data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py,sha256
+data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py,sha256=W_QSYNO2ynsXGJ71y_M9uRpYjjcbcAFhp1MpDFdl9YM,1844
 data_designer/engine/sampling_gen/entities/email_address_utils.py,sha256=-V4zuuFq1t3nzzO_FqzCWApPcWNKAh-ZQYFMmCiu5RE,5231
 data_designer/engine/sampling_gen/entities/errors.py,sha256=QEq-6Ld9OlModEYbse0pvY21OC5CyO-OalrL03-iLME,311
 data_designer/engine/sampling_gen/entities/national_id_utils.py,sha256=vxxHnrfQP98W8dWGysCjvfIT-h1xEGdfxn5xF_-UeXw,2611
@@ -163,15 +163,15 @@ data_designer/engine/validators/remote.py,sha256=jtDIvWzfHh17m2ac_Fp93p49Th8RlkB
 data_designer/engine/validators/sql.py,sha256=bxbyxPxDT9yuwjhABVEY40iR1pzWRFi65WU4tPgG2bE,2250
 data_designer/essentials/__init__.py,sha256=zrDZ7hahOmOhCPdfoj0z9ALN10lXIesfwd2qXRqTcdY,4125
 data_designer/interface/__init__.py,sha256=9eG4WHKyrJcNoK4GEz6BCw_E0Ewo9elQoDN4TLMbAog,137
-data_designer/interface/data_designer.py,sha256=
+data_designer/interface/data_designer.py,sha256=USPTruC5axBJNEWEnYBJ4ol2d3mXGubHELBmWeahFe8,16664
 data_designer/interface/errors.py,sha256=jagKT3tPUnYq4e3e6AkTnBkcayHyEfxjPMBzx-GEKe4,565
 data_designer/interface/results.py,sha256=qFxa8SuCXeADiRpaCMBwJcExkJBCfUPeGCdcJSTjoTc,2111
 data_designer/plugins/__init__.py,sha256=c_V7q4QhfVoNf_uc9UwmXCsWqwtyWogI7YoN_0PzzE4,234
 data_designer/plugins/errors.py,sha256=yPIHpSddEr-o9ZcNVibb2hI-73O15Kg_Od8SlmQlnRs,297
 data_designer/plugins/plugin.py,sha256=7ErdUyrTdOb5PCBE3msdhTOrvQpldjOQw90-Bu4Bosc,2522
 data_designer/plugins/registry.py,sha256=iPDTh4duV1cKt7H1fXkj1bKLG6SyUKmzQ9xh-vjEoaM,3018
-data_designer-0.1.1.dist-info/METADATA,sha256=
-data_designer-0.1.1.dist-info/WHEEL,sha256=
-data_designer-0.1.1.dist-info/entry_points.txt,sha256=
-data_designer-0.1.1.dist-info/licenses/LICENSE,sha256=
-data_designer-0.1.1.dist-info/RECORD,,
+data_designer-0.1.3.dist-info/METADATA,sha256=fCI36BVPIOC7FVxQviBmzWMX8HRnc69afkJ82xPYXbY,6644
+data_designer-0.1.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+data_designer-0.1.3.dist-info/entry_points.txt,sha256=NWWWidyDxN6CYX6y664PhBYMhbaYTQTyprqfYAgkyCg,57
+data_designer-0.1.3.dist-info/licenses/LICENSE,sha256=cSWJDwVqHyQgly8Zmt3pqXJ2eQbZVYwN9qd0NMssxXY,11336
+data_designer-0.1.3.dist-info/RECORD,,
{data_designer-0.1.1.dist-info → data_designer-0.1.3.dist-info}/entry_points.txt
File without changes
{data_designer-0.1.1.dist-info → data_designer-0.1.3.dist-info}/licenses/LICENSE
File without changes