data-designer 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/cli/README.md +15 -1
- data_designer/cli/commands/download.py +56 -0
- data_designer/cli/commands/list.py +4 -18
- data_designer/cli/controllers/__init__.py +2 -1
- data_designer/cli/controllers/download_controller.py +217 -0
- data_designer/cli/controllers/model_controller.py +4 -3
- data_designer/cli/forms/field.py +65 -19
- data_designer/cli/forms/model_builder.py +251 -44
- data_designer/cli/main.py +11 -1
- data_designer/cli/repositories/persona_repository.py +88 -0
- data_designer/cli/services/__init__.py +2 -1
- data_designer/cli/services/download_service.py +97 -0
- data_designer/cli/ui.py +131 -0
- data_designer/cli/utils.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +75 -7
- data_designer/config/analysis/column_statistics.py +192 -48
- data_designer/config/analysis/dataset_profiler.py +23 -5
- data_designer/config/analysis/utils/reporting.py +3 -3
- data_designer/config/base.py +3 -3
- data_designer/config/column_configs.py +27 -6
- data_designer/config/column_types.py +24 -17
- data_designer/config/config_builder.py +34 -26
- data_designer/config/data_designer_config.py +7 -7
- data_designer/config/datastore.py +6 -6
- data_designer/config/default_model_settings.py +27 -34
- data_designer/config/exports.py +14 -1
- data_designer/config/models.py +155 -29
- data_designer/config/preview_results.py +5 -4
- data_designer/config/processors.py +109 -4
- data_designer/config/sampler_constraints.py +1 -2
- data_designer/config/sampler_params.py +31 -31
- data_designer/config/seed.py +1 -2
- data_designer/config/utils/code_lang.py +4 -5
- data_designer/config/utils/constants.py +31 -8
- data_designer/config/utils/io_helpers.py +5 -5
- data_designer/config/utils/misc.py +1 -4
- data_designer/config/utils/numerical_helpers.py +2 -2
- data_designer/config/utils/type_helpers.py +3 -3
- data_designer/config/utils/validation.py +39 -9
- data_designer/config/utils/visualization.py +62 -15
- data_designer/config/validator_params.py +4 -8
- data_designer/engine/analysis/column_profilers/base.py +0 -7
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
- data_designer/engine/analysis/column_statistics.py +16 -16
- data_designer/engine/analysis/dataset_profiler.py +25 -4
- data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
- data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
- data_designer/engine/column_generators/generators/base.py +34 -0
- data_designer/engine/column_generators/generators/embedding.py +45 -0
- data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
- data_designer/engine/column_generators/registry.py +4 -2
- data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
- data_designer/engine/configurable_task.py +2 -2
- data_designer/engine/dataset_builders/artifact_storage.py +14 -5
- data_designer/engine/dataset_builders/column_wise_builder.py +12 -8
- data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
- data_designer/engine/models/facade.py +66 -9
- data_designer/engine/models/litellm_overrides.py +5 -6
- data_designer/engine/models/parsers/errors.py +2 -4
- data_designer/engine/models/parsers/parser.py +2 -3
- data_designer/engine/models/parsers/postprocessors.py +3 -4
- data_designer/engine/models/parsers/types.py +4 -4
- data_designer/engine/models/registry.py +20 -11
- data_designer/engine/models/usage.py +7 -9
- data_designer/engine/processing/ginja/ast.py +1 -2
- data_designer/engine/processing/processors/drop_columns.py +1 -1
- data_designer/engine/processing/processors/registry.py +3 -0
- data_designer/engine/processing/processors/schema_transform.py +53 -0
- data_designer/engine/processing/utils.py +40 -2
- data_designer/engine/registry/base.py +12 -12
- data_designer/engine/sampling_gen/constraints.py +1 -2
- data_designer/engine/sampling_gen/data_sources/base.py +14 -14
- data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
- data_designer/engine/sampling_gen/people_gen.py +3 -7
- data_designer/engine/validators/base.py +2 -2
- data_designer/interface/data_designer.py +12 -0
- data_designer/interface/results.py +36 -0
- data_designer/logging.py +2 -2
- data_designer/plugin_manager.py +3 -3
- data_designer/plugins/plugin.py +3 -3
- data_designer/plugins/registry.py +2 -2
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/METADATA +9 -9
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/RECORD +88 -81
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
4
|
from enum import Enum
|
|
5
|
-
from typing import Literal
|
|
5
|
+
from typing import Literal
|
|
6
6
|
|
|
7
7
|
import pandas as pd
|
|
8
8
|
from pydantic import Field, field_validator, model_validator
|
|
@@ -54,12 +54,12 @@ class CategorySamplerParams(ConfigBase):
|
|
|
54
54
|
Larger weights result in higher sampling probability for the corresponding value.
|
|
55
55
|
"""
|
|
56
56
|
|
|
57
|
-
values: list[
|
|
57
|
+
values: list[str | int | float] = Field(
|
|
58
58
|
...,
|
|
59
59
|
min_length=1,
|
|
60
60
|
description="List of possible categorical values that can be sampled from.",
|
|
61
61
|
)
|
|
62
|
-
weights:
|
|
62
|
+
weights: list[float] | None = Field(
|
|
63
63
|
default=None,
|
|
64
64
|
description=(
|
|
65
65
|
"List of unnormalized probability weights to assigned to each value, in order. "
|
|
@@ -134,7 +134,7 @@ class SubcategorySamplerParams(ConfigBase):
|
|
|
134
134
|
"""
|
|
135
135
|
|
|
136
136
|
category: str = Field(..., description="Name of parent category to this subcategory.")
|
|
137
|
-
values: dict[str, list[
|
|
137
|
+
values: dict[str, list[str | int | float]] = Field(
|
|
138
138
|
...,
|
|
139
139
|
description="Mapping from each value of parent category to a list of subcategory values.",
|
|
140
140
|
)
|
|
@@ -214,7 +214,7 @@ class UUIDSamplerParams(ConfigBase):
|
|
|
214
214
|
lowercase UUIDs.
|
|
215
215
|
"""
|
|
216
216
|
|
|
217
|
-
prefix:
|
|
217
|
+
prefix: str | None = Field(default=None, description="String prepended to the front of the UUID.")
|
|
218
218
|
short_form: bool = Field(
|
|
219
219
|
default=False,
|
|
220
220
|
description="If true, all UUIDs sampled will be truncated at 8 characters.",
|
|
@@ -259,7 +259,7 @@ class ScipySamplerParams(ConfigBase):
|
|
|
259
259
|
...,
|
|
260
260
|
description="Parameters of the scipy.stats distribution given in `dist_name`.",
|
|
261
261
|
)
|
|
262
|
-
decimal_places:
|
|
262
|
+
decimal_places: int | None = Field(
|
|
263
263
|
default=None, description="Number of decimal places to round the sampled values to."
|
|
264
264
|
)
|
|
265
265
|
sampler_type: Literal[SamplerType.SCIPY] = SamplerType.SCIPY
|
|
@@ -356,7 +356,7 @@ class GaussianSamplerParams(ConfigBase):
|
|
|
356
356
|
|
|
357
357
|
mean: float = Field(..., description="Mean of the Gaussian distribution")
|
|
358
358
|
stddev: float = Field(..., description="Standard deviation of the Gaussian distribution")
|
|
359
|
-
decimal_places:
|
|
359
|
+
decimal_places: int | None = Field(
|
|
360
360
|
default=None, description="Number of decimal places to round the sampled values to."
|
|
361
361
|
)
|
|
362
362
|
sampler_type: Literal[SamplerType.GAUSSIAN] = SamplerType.GAUSSIAN
|
|
@@ -398,7 +398,7 @@ class UniformSamplerParams(ConfigBase):
|
|
|
398
398
|
|
|
399
399
|
low: float = Field(..., description="Lower bound of the uniform distribution, inclusive.")
|
|
400
400
|
high: float = Field(..., description="Upper bound of the uniform distribution, inclusive.")
|
|
401
|
-
decimal_places:
|
|
401
|
+
decimal_places: int | None = Field(
|
|
402
402
|
default=None, description="Number of decimal places to round the sampled values to."
|
|
403
403
|
)
|
|
404
404
|
sampler_type: Literal[SamplerType.UNIFORM] = SamplerType.UNIFORM
|
|
@@ -421,8 +421,8 @@ class PersonSamplerParams(ConfigBase):
|
|
|
421
421
|
|
|
422
422
|
Attributes:
|
|
423
423
|
locale: Locale string determining the language and geographic region for synthetic people.
|
|
424
|
-
|
|
425
|
-
|
|
424
|
+
Must be a locale supported by a managed Nemotron Personas dataset. The dataset must
|
|
425
|
+
be downloaded and available in the managed assets directory.
|
|
426
426
|
sex: If specified, filters to only sample people of the specified sex. Options: "Male" or
|
|
427
427
|
"Female". If None, samples both sexes.
|
|
428
428
|
city: If specified, filters to only sample people from the specified city or cities. Can be
|
|
@@ -447,11 +447,11 @@ class PersonSamplerParams(ConfigBase):
|
|
|
447
447
|
f"{', '.join(LOCALES_WITH_MANAGED_DATASETS)}."
|
|
448
448
|
),
|
|
449
449
|
)
|
|
450
|
-
sex:
|
|
450
|
+
sex: SexT | None = Field(
|
|
451
451
|
default=None,
|
|
452
452
|
description="If specified, then only synthetic people of the specified sex will be sampled.",
|
|
453
453
|
)
|
|
454
|
-
city:
|
|
454
|
+
city: str | list[str] | None = Field(
|
|
455
455
|
default=None,
|
|
456
456
|
description="If specified, then only synthetic people from these cities will be sampled.",
|
|
457
457
|
)
|
|
@@ -461,7 +461,7 @@ class PersonSamplerParams(ConfigBase):
|
|
|
461
461
|
min_length=2,
|
|
462
462
|
max_length=2,
|
|
463
463
|
)
|
|
464
|
-
select_field_values:
|
|
464
|
+
select_field_values: dict[str, list[str]] | None = Field(
|
|
465
465
|
default=None,
|
|
466
466
|
description=(
|
|
467
467
|
"Sample synthetic people with the specified field values. This is meant to be a flexible argument for "
|
|
@@ -529,11 +529,11 @@ class PersonFromFakerSamplerParams(ConfigBase):
|
|
|
529
529
|
"that a synthetic person will be sampled from. E.g, en_US, en_GB, fr_FR, ..."
|
|
530
530
|
),
|
|
531
531
|
)
|
|
532
|
-
sex:
|
|
532
|
+
sex: SexT | None = Field(
|
|
533
533
|
default=None,
|
|
534
534
|
description="If specified, then only synthetic people of the specified sex will be sampled.",
|
|
535
535
|
)
|
|
536
|
-
city:
|
|
536
|
+
city: str | list[str] | None = Field(
|
|
537
537
|
default=None,
|
|
538
538
|
description="If specified, then only synthetic people from these cities will be sampled.",
|
|
539
539
|
)
|
|
@@ -585,22 +585,22 @@ class PersonFromFakerSamplerParams(ConfigBase):
|
|
|
585
585
|
return value
|
|
586
586
|
|
|
587
587
|
|
|
588
|
-
SamplerParamsT: TypeAlias =
|
|
589
|
-
SubcategorySamplerParams
|
|
590
|
-
CategorySamplerParams
|
|
591
|
-
DatetimeSamplerParams
|
|
592
|
-
PersonSamplerParams
|
|
593
|
-
PersonFromFakerSamplerParams
|
|
594
|
-
TimeDeltaSamplerParams
|
|
595
|
-
UUIDSamplerParams
|
|
596
|
-
BernoulliSamplerParams
|
|
597
|
-
BernoulliMixtureSamplerParams
|
|
598
|
-
BinomialSamplerParams
|
|
599
|
-
GaussianSamplerParams
|
|
600
|
-
PoissonSamplerParams
|
|
601
|
-
UniformSamplerParams
|
|
602
|
-
ScipySamplerParams
|
|
603
|
-
|
|
588
|
+
SamplerParamsT: TypeAlias = (
|
|
589
|
+
SubcategorySamplerParams
|
|
590
|
+
| CategorySamplerParams
|
|
591
|
+
| DatetimeSamplerParams
|
|
592
|
+
| PersonSamplerParams
|
|
593
|
+
| PersonFromFakerSamplerParams
|
|
594
|
+
| TimeDeltaSamplerParams
|
|
595
|
+
| UUIDSamplerParams
|
|
596
|
+
| BernoulliSamplerParams
|
|
597
|
+
| BernoulliMixtureSamplerParams
|
|
598
|
+
| BinomialSamplerParams
|
|
599
|
+
| GaussianSamplerParams
|
|
600
|
+
| PoissonSamplerParams
|
|
601
|
+
| UniformSamplerParams
|
|
602
|
+
| ScipySamplerParams
|
|
603
|
+
)
|
|
604
604
|
|
|
605
605
|
|
|
606
606
|
def is_numerical_sampler_type(sampler_type: SamplerType) -> bool:
|
data_designer/config/seed.py
CHANGED
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
|
|
4
4
|
from abc import ABC
|
|
5
5
|
from enum import Enum
|
|
6
|
-
from typing import Optional, Union
|
|
7
6
|
|
|
8
7
|
from pydantic import Field, field_validator, model_validator
|
|
9
8
|
from typing_extensions import Self
|
|
@@ -112,7 +111,7 @@ class SeedConfig(ConfigBase):
|
|
|
112
111
|
|
|
113
112
|
dataset: str
|
|
114
113
|
sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED
|
|
115
|
-
selection_strategy:
|
|
114
|
+
selection_strategy: IndexRange | PartitionBlock | None = None
|
|
116
115
|
|
|
117
116
|
|
|
118
117
|
class SeedDatasetReference(ABC, ConfigBase):
|
|
@@ -4,7 +4,6 @@
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
6
|
from enum import Enum
|
|
7
|
-
from typing import Union
|
|
8
7
|
|
|
9
8
|
|
|
10
9
|
class CodeLang(str, Enum):
|
|
@@ -26,17 +25,17 @@ class CodeLang(str, Enum):
|
|
|
26
25
|
SQL_ANSI = "sql:ansi"
|
|
27
26
|
|
|
28
27
|
@staticmethod
|
|
29
|
-
def parse(value:
|
|
28
|
+
def parse(value: str | CodeLang) -> tuple[str, str | None]:
|
|
30
29
|
value = value.value if isinstance(value, CodeLang) else value
|
|
31
30
|
split_vals = value.split(":")
|
|
32
31
|
return (split_vals[0], split_vals[1] if len(split_vals) > 1 else None)
|
|
33
32
|
|
|
34
33
|
@staticmethod
|
|
35
|
-
def parse_lang(value:
|
|
34
|
+
def parse_lang(value: str | CodeLang) -> str:
|
|
36
35
|
return CodeLang.parse(value)[0]
|
|
37
36
|
|
|
38
37
|
@staticmethod
|
|
39
|
-
def parse_dialect(value:
|
|
38
|
+
def parse_dialect(value: str | CodeLang) -> str | None:
|
|
40
39
|
return CodeLang.parse(value)[1]
|
|
41
40
|
|
|
42
41
|
@staticmethod
|
|
@@ -58,7 +57,7 @@ SQL_DIALECTS: set[CodeLang] = {
|
|
|
58
57
|
##########################################################
|
|
59
58
|
|
|
60
59
|
|
|
61
|
-
def code_lang_to_syntax_lexer(code_lang:
|
|
60
|
+
def code_lang_to_syntax_lexer(code_lang: CodeLang | str) -> str:
|
|
62
61
|
"""Convert the code language to a syntax lexer for Pygments.
|
|
63
62
|
|
|
64
63
|
Reference: https://pygments.org/docs/lexers/
|
|
@@ -97,8 +97,6 @@ DEFAULT_AGE_RANGE = [18, 114]
|
|
|
97
97
|
MIN_AGE = 0
|
|
98
98
|
MAX_AGE = 114
|
|
99
99
|
|
|
100
|
-
LOCALES_WITH_MANAGED_DATASETS = ["en_US", "ja_JP", "en_IN", "hi_IN"]
|
|
101
|
-
|
|
102
100
|
US_STATES_AND_MAJOR_TERRITORIES = {
|
|
103
101
|
# States
|
|
104
102
|
"AK",
|
|
@@ -299,15 +297,40 @@ PREDEFINED_PROVIDERS = [
|
|
|
299
297
|
},
|
|
300
298
|
]
|
|
301
299
|
|
|
300
|
+
|
|
301
|
+
DEFAULT_TEXT_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
|
|
302
|
+
DEFAULT_REASONING_INFERENCE_PARAMS = {"temperature": 0.35, "top_p": 0.95}
|
|
303
|
+
DEFAULT_VISION_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
|
|
304
|
+
DEFAULT_EMBEDDING_INFERENCE_PARAMS = {"encoding_format": "float"}
|
|
305
|
+
|
|
306
|
+
|
|
302
307
|
PREDEFINED_PROVIDERS_MODEL_MAP = {
|
|
303
308
|
NVIDIA_PROVIDER_NAME: {
|
|
304
|
-
"text": "nvidia/
|
|
305
|
-
"reasoning": "openai/gpt-oss-20b",
|
|
306
|
-
"vision": "nvidia/nemotron-nano-12b-v2-vl",
|
|
309
|
+
"text": {"model": "nvidia/nemotron-3-nano-30b-a3b", "inference_parameters": {"temperature": 1.0, "top_p": 1.0}},
|
|
310
|
+
"reasoning": {"model": "openai/gpt-oss-20b", "inference_parameters": DEFAULT_REASONING_INFERENCE_PARAMS},
|
|
311
|
+
"vision": {"model": "nvidia/nemotron-nano-12b-v2-vl", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
|
|
312
|
+
"embedding": {
|
|
313
|
+
"model": "nvidia/llama-3.2-nv-embedqa-1b-v2",
|
|
314
|
+
"inference_parameters": DEFAULT_EMBEDDING_INFERENCE_PARAMS | {"extra_body": {"input_type": "query"}},
|
|
315
|
+
},
|
|
307
316
|
},
|
|
308
317
|
OPENAI_PROVIDER_NAME: {
|
|
309
|
-
"text": "gpt-4.1",
|
|
310
|
-
"reasoning": "gpt-5",
|
|
311
|
-
"vision": "gpt-5",
|
|
318
|
+
"text": {"model": "gpt-4.1", "inference_parameters": DEFAULT_TEXT_INFERENCE_PARAMS},
|
|
319
|
+
"reasoning": {"model": "gpt-5", "inference_parameters": DEFAULT_REASONING_INFERENCE_PARAMS},
|
|
320
|
+
"vision": {"model": "gpt-5", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
|
|
321
|
+
"embedding": {"model": "text-embedding-3-large", "inference_parameters": DEFAULT_EMBEDDING_INFERENCE_PARAMS},
|
|
312
322
|
},
|
|
313
323
|
}
|
|
324
|
+
|
|
325
|
+
# Persona locale metadata - used by the CLI and the person sampler.
|
|
326
|
+
NEMOTRON_PERSONAS_DATASET_SIZES = {
|
|
327
|
+
"en_US": "1.24 GB",
|
|
328
|
+
"en_IN": "2.39 GB",
|
|
329
|
+
"hi_Deva_IN": "4.14 GB",
|
|
330
|
+
"hi_Latn_IN": "2.7 GB",
|
|
331
|
+
"ja_JP": "1.69 GB",
|
|
332
|
+
}
|
|
333
|
+
|
|
334
|
+
LOCALES_WITH_MANAGED_DATASETS = list[str](NEMOTRON_PERSONAS_DATASET_SIZES.keys())
|
|
335
|
+
|
|
336
|
+
NEMOTRON_PERSONAS_DATASET_PREFIX = "nemotron-personas-dataset-"
|
|
@@ -8,7 +8,7 @@ from datetime import date, datetime, timedelta
|
|
|
8
8
|
from decimal import Decimal
|
|
9
9
|
from numbers import Number
|
|
10
10
|
from pathlib import Path
|
|
11
|
-
from typing import Any
|
|
11
|
+
from typing import Any
|
|
12
12
|
|
|
13
13
|
import numpy as np
|
|
14
14
|
import pandas as pd
|
|
@@ -128,7 +128,7 @@ def write_seed_dataset(dataframe: pd.DataFrame, file_path: Path) -> None:
|
|
|
128
128
|
dataframe.to_json(file_path, orient="records", lines=True)
|
|
129
129
|
|
|
130
130
|
|
|
131
|
-
def validate_dataset_file_path(file_path:
|
|
131
|
+
def validate_dataset_file_path(file_path: str | Path, should_exist: bool = True) -> Path:
|
|
132
132
|
"""Validate that a dataset file path has a valid extension and optionally exists.
|
|
133
133
|
|
|
134
134
|
Args:
|
|
@@ -165,7 +165,7 @@ def validate_path_contains_files_of_type(path: str | Path, file_extension: str)
|
|
|
165
165
|
raise InvalidFilePathError(f"🛑 Path {path!r} does not contain files of type {file_extension!r}.")
|
|
166
166
|
|
|
167
167
|
|
|
168
|
-
def smart_load_dataframe(dataframe:
|
|
168
|
+
def smart_load_dataframe(dataframe: str | Path | pd.DataFrame) -> pd.DataFrame:
|
|
169
169
|
"""Load a dataframe from file if a path is given, otherwise return the dataframe.
|
|
170
170
|
|
|
171
171
|
Args:
|
|
@@ -197,7 +197,7 @@ def smart_load_dataframe(dataframe: Union[str, Path, pd.DataFrame]) -> pd.DataFr
|
|
|
197
197
|
raise ValueError(f"Unsupported file format: {dataframe}")
|
|
198
198
|
|
|
199
199
|
|
|
200
|
-
def smart_load_yaml(yaml_in:
|
|
200
|
+
def smart_load_yaml(yaml_in: str | Path | dict) -> dict:
|
|
201
201
|
"""Return the yaml config as a dict given flexible input types.
|
|
202
202
|
|
|
203
203
|
Args:
|
|
@@ -227,7 +227,7 @@ def smart_load_yaml(yaml_in: Union[str, Path, dict]) -> dict:
|
|
|
227
227
|
return yaml_out
|
|
228
228
|
|
|
229
229
|
|
|
230
|
-
def serialize_data(data:
|
|
230
|
+
def serialize_data(data: dict | list | str | Number, **kwargs) -> str:
|
|
231
231
|
if isinstance(data, dict):
|
|
232
232
|
return json.dumps(data, ensure_ascii=False, default=_convert_to_serializable, **kwargs)
|
|
233
233
|
elif isinstance(data, list):
|
|
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
import json
|
|
7
7
|
from contextlib import contextmanager
|
|
8
|
-
from typing import Optional, Union
|
|
9
8
|
|
|
10
9
|
from jinja2 import TemplateSyntaxError, meta
|
|
11
10
|
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
|
@@ -58,9 +57,7 @@ def get_prompt_template_keywords(template: str) -> set[str]:
|
|
|
58
57
|
return keywords
|
|
59
58
|
|
|
60
59
|
|
|
61
|
-
def json_indent_list_of_strings(
|
|
62
|
-
column_names: list[str], *, indent: Optional[Union[int, str]] = None
|
|
63
|
-
) -> Optional[Union[list[str], str]]:
|
|
60
|
+
def json_indent_list_of_strings(column_names: list[str], *, indent: int | str | None = None) -> list[str] | str | None:
|
|
64
61
|
"""Convert a list of column names to a JSON string if the list is long.
|
|
65
62
|
|
|
66
63
|
This function helps keep Data Designer's __repr__ output clean and readable.
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
|
|
4
4
|
import numbers
|
|
5
5
|
from numbers import Number
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
from data_designer.config.utils.constants import REPORTING_PRECISION
|
|
9
9
|
|
|
@@ -18,7 +18,7 @@ def is_float(val: Any) -> bool:
|
|
|
18
18
|
|
|
19
19
|
def prepare_number_for_reporting(
|
|
20
20
|
value: Number,
|
|
21
|
-
target_type:
|
|
21
|
+
target_type: type[Number],
|
|
22
22
|
precision: int = REPORTING_PRECISION,
|
|
23
23
|
) -> Number:
|
|
24
24
|
"""Ensure native python types and round to `precision` decimal digits."""
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
|
|
4
4
|
import inspect
|
|
5
5
|
from enum import Enum
|
|
6
|
-
from typing import Any, Literal,
|
|
6
|
+
from typing import Any, Literal, get_args, get_origin
|
|
7
7
|
|
|
8
8
|
from pydantic import BaseModel
|
|
9
9
|
|
|
@@ -56,7 +56,7 @@ def create_str_enum_from_discriminated_type_union(
|
|
|
56
56
|
return StrEnum(enum_name, {v.replace("-", "_").upper(): v for v in set(discriminator_field_values)})
|
|
57
57
|
|
|
58
58
|
|
|
59
|
-
def get_sampler_params() -> dict[str,
|
|
59
|
+
def get_sampler_params() -> dict[str, type[BaseModel]]:
|
|
60
60
|
"""Returns a dictionary of sampler parameter classes."""
|
|
61
61
|
params_cls_list = [
|
|
62
62
|
params_cls
|
|
@@ -83,7 +83,7 @@ def get_sampler_params() -> dict[str, Type[BaseModel]]:
|
|
|
83
83
|
return params_cls_dict
|
|
84
84
|
|
|
85
85
|
|
|
86
|
-
def resolve_string_enum(enum_instance: Any, enum_type:
|
|
86
|
+
def resolve_string_enum(enum_instance: Any, enum_type: type[Enum]) -> Enum:
|
|
87
87
|
if not issubclass(enum_type, Enum):
|
|
88
88
|
raise InvalidEnumValueError(f"🛑 `enum_type` must be a subclass of Enum. You provided: {enum_type}")
|
|
89
89
|
invalid_enum_value_error = InvalidEnumValueError(
|
|
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
from enum import Enum
|
|
7
7
|
from string import Formatter
|
|
8
|
-
from typing import Optional
|
|
9
8
|
|
|
10
9
|
from jinja2 import meta
|
|
11
10
|
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
|
@@ -15,10 +14,13 @@ from rich.console import Console, Group
|
|
|
15
14
|
from rich.padding import Padding
|
|
16
15
|
from rich.panel import Panel
|
|
17
16
|
|
|
18
|
-
from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType,
|
|
19
|
-
from data_designer.config.processors import
|
|
17
|
+
from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_model_generated
|
|
18
|
+
from data_designer.config.processors import ProcessorConfigT, ProcessorType
|
|
20
19
|
from data_designer.config.utils.constants import RICH_CONSOLE_THEME
|
|
21
|
-
from data_designer.config.utils.misc import
|
|
20
|
+
from data_designer.config.utils.misc import (
|
|
21
|
+
can_run_data_designer_locally,
|
|
22
|
+
get_prompt_template_keywords,
|
|
23
|
+
)
|
|
22
24
|
from data_designer.config.validator_params import ValidatorType
|
|
23
25
|
|
|
24
26
|
|
|
@@ -42,7 +44,7 @@ class ViolationLevel(str, Enum):
|
|
|
42
44
|
|
|
43
45
|
|
|
44
46
|
class Violation(BaseModel):
|
|
45
|
-
column:
|
|
47
|
+
column: str | None = None
|
|
46
48
|
type: ViolationType
|
|
47
49
|
message: str
|
|
48
50
|
level: ViolationLevel
|
|
@@ -54,7 +56,7 @@ class Violation(BaseModel):
|
|
|
54
56
|
|
|
55
57
|
def validate_data_designer_config(
|
|
56
58
|
columns: list[ColumnConfigT],
|
|
57
|
-
processor_configs: list[
|
|
59
|
+
processor_configs: list[ProcessorConfigT],
|
|
58
60
|
allowed_references: list[str],
|
|
59
61
|
) -> list[Violation]:
|
|
60
62
|
violations = []
|
|
@@ -63,6 +65,7 @@ def validate_data_designer_config(
|
|
|
63
65
|
violations.extend(validate_expression_references(columns=columns, allowed_references=allowed_references))
|
|
64
66
|
violations.extend(validate_columns_not_all_dropped(columns=columns))
|
|
65
67
|
violations.extend(validate_drop_columns_processor(columns=columns, processor_configs=processor_configs))
|
|
68
|
+
violations.extend(validate_schema_transform_processor(columns=columns, processor_configs=processor_configs))
|
|
66
69
|
if not can_run_data_designer_locally():
|
|
67
70
|
violations.extend(validate_local_only_columns(columns=columns))
|
|
68
71
|
return violations
|
|
@@ -115,7 +118,7 @@ def validate_prompt_templates(
|
|
|
115
118
|
) -> list[Violation]:
|
|
116
119
|
env = ImmutableSandboxedEnvironment()
|
|
117
120
|
|
|
118
|
-
columns_with_prompts = [c for c in columns if
|
|
121
|
+
columns_with_prompts = [c for c in columns if column_type_is_model_generated(c.column_type)]
|
|
119
122
|
|
|
120
123
|
violations = []
|
|
121
124
|
for column in columns_with_prompts:
|
|
@@ -269,9 +272,9 @@ def validate_columns_not_all_dropped(
|
|
|
269
272
|
|
|
270
273
|
def validate_drop_columns_processor(
|
|
271
274
|
columns: list[ColumnConfigT],
|
|
272
|
-
processor_configs: list[
|
|
275
|
+
processor_configs: list[ProcessorConfigT],
|
|
273
276
|
) -> list[Violation]:
|
|
274
|
-
all_column_names =
|
|
277
|
+
all_column_names = {c.name for c in columns}
|
|
275
278
|
for processor_config in processor_configs:
|
|
276
279
|
if processor_config.processor_type == ProcessorType.DROP_COLUMNS:
|
|
277
280
|
invalid_columns = set(processor_config.column_names) - all_column_names
|
|
@@ -288,6 +291,33 @@ def validate_drop_columns_processor(
|
|
|
288
291
|
return []
|
|
289
292
|
|
|
290
293
|
|
|
294
|
+
def validate_schema_transform_processor(
|
|
295
|
+
columns: list[ColumnConfigT],
|
|
296
|
+
processor_configs: list[ProcessorConfigT],
|
|
297
|
+
) -> list[Violation]:
|
|
298
|
+
violations = []
|
|
299
|
+
|
|
300
|
+
all_column_names = {c.name for c in columns}
|
|
301
|
+
for processor_config in processor_configs:
|
|
302
|
+
if processor_config.processor_type == ProcessorType.SCHEMA_TRANSFORM:
|
|
303
|
+
for col, template in processor_config.template.items():
|
|
304
|
+
template_keywords = get_prompt_template_keywords(template)
|
|
305
|
+
invalid_keywords = set(template_keywords) - all_column_names
|
|
306
|
+
if len(invalid_keywords) > 0:
|
|
307
|
+
invalid_keywords = ", ".join([f"'{k}'" for k in invalid_keywords])
|
|
308
|
+
message = f"Ancillary dataset processor attempts to reference columns {invalid_keywords} in the template for '{col}', but the columns are not defined in the dataset."
|
|
309
|
+
violations.append(
|
|
310
|
+
Violation(
|
|
311
|
+
column=None,
|
|
312
|
+
type=ViolationType.INVALID_REFERENCE,
|
|
313
|
+
message=message,
|
|
314
|
+
level=ViolationLevel.ERROR,
|
|
315
|
+
)
|
|
316
|
+
)
|
|
317
|
+
|
|
318
|
+
return violations
|
|
319
|
+
|
|
320
|
+
|
|
291
321
|
def validate_expression_references(
|
|
292
322
|
columns: list[ColumnConfigT],
|
|
293
323
|
allowed_references: list[str],
|
|
@@ -8,7 +8,7 @@ import os
|
|
|
8
8
|
from collections import OrderedDict
|
|
9
9
|
from enum import Enum
|
|
10
10
|
from functools import cached_property
|
|
11
|
-
from typing import TYPE_CHECKING,
|
|
11
|
+
from typing import TYPE_CHECKING, Any
|
|
12
12
|
|
|
13
13
|
import numpy as np
|
|
14
14
|
import pandas as pd
|
|
@@ -36,11 +36,11 @@ if TYPE_CHECKING:
|
|
|
36
36
|
console = Console()
|
|
37
37
|
|
|
38
38
|
|
|
39
|
-
def get_nvidia_api_key() ->
|
|
39
|
+
def get_nvidia_api_key() -> str | None:
|
|
40
40
|
return os.getenv(NVIDIA_API_KEY_ENV_VAR_NAME)
|
|
41
41
|
|
|
42
42
|
|
|
43
|
-
def get_openai_api_key() ->
|
|
43
|
+
def get_openai_api_key() -> str | None:
|
|
44
44
|
return os.getenv(OPENAI_API_KEY_ENV_VAR_NAME)
|
|
45
45
|
|
|
46
46
|
|
|
@@ -72,13 +72,17 @@ class WithRecordSamplerMixin:
|
|
|
72
72
|
else:
|
|
73
73
|
raise DatasetSampleDisplayError("No valid dataset found in results object.")
|
|
74
74
|
|
|
75
|
+
def _has_processor_artifacts(self) -> bool:
|
|
76
|
+
return hasattr(self, "processor_artifacts") and self.processor_artifacts is not None
|
|
77
|
+
|
|
75
78
|
def display_sample_record(
|
|
76
79
|
self,
|
|
77
|
-
index:
|
|
80
|
+
index: int | None = None,
|
|
78
81
|
*,
|
|
79
82
|
hide_seed_columns: bool = False,
|
|
80
83
|
syntax_highlighting_theme: str = "dracula",
|
|
81
|
-
background_color:
|
|
84
|
+
background_color: str | None = None,
|
|
85
|
+
processors_to_display: list[str] | None = None,
|
|
82
86
|
) -> None:
|
|
83
87
|
"""Display a sample record from the Data Designer dataset preview.
|
|
84
88
|
|
|
@@ -90,6 +94,7 @@ class WithRecordSamplerMixin:
|
|
|
90
94
|
documentation from `rich` for information about available themes.
|
|
91
95
|
background_color: Background color to use for the record. See the `Syntax`
|
|
92
96
|
documentation from `rich` for information about available background colors.
|
|
97
|
+
processors_to_display: List of processors to display the artifacts for. If None, all processors will be displayed.
|
|
93
98
|
"""
|
|
94
99
|
i = index or self._display_cycle_index
|
|
95
100
|
|
|
@@ -99,8 +104,25 @@ class WithRecordSamplerMixin:
|
|
|
99
104
|
except IndexError:
|
|
100
105
|
raise DatasetSampleDisplayError(f"Index {i} is out of bounds for dataset of length {num_records}.")
|
|
101
106
|
|
|
107
|
+
processor_data_to_display = None
|
|
108
|
+
if self._has_processor_artifacts() and len(self.processor_artifacts) > 0:
|
|
109
|
+
if processors_to_display is None:
|
|
110
|
+
processors_to_display = list(self.processor_artifacts.keys())
|
|
111
|
+
|
|
112
|
+
if len(processors_to_display) > 0:
|
|
113
|
+
processor_data_to_display = {}
|
|
114
|
+
for processor in processors_to_display:
|
|
115
|
+
if (
|
|
116
|
+
isinstance(self.processor_artifacts[processor], list)
|
|
117
|
+
and len(self.processor_artifacts[processor]) == num_records
|
|
118
|
+
):
|
|
119
|
+
processor_data_to_display[processor] = self.processor_artifacts[processor][i]
|
|
120
|
+
else:
|
|
121
|
+
processor_data_to_display[processor] = self.processor_artifacts[processor]
|
|
122
|
+
|
|
102
123
|
display_sample_record(
|
|
103
124
|
record=record,
|
|
125
|
+
processor_data_to_display=processor_data_to_display,
|
|
104
126
|
config_builder=self._config_builder,
|
|
105
127
|
background_color=background_color,
|
|
106
128
|
syntax_highlighting_theme=syntax_highlighting_theme,
|
|
@@ -112,11 +134,11 @@ class WithRecordSamplerMixin:
|
|
|
112
134
|
|
|
113
135
|
|
|
114
136
|
def create_rich_histogram_table(
|
|
115
|
-
data: dict[str,
|
|
137
|
+
data: dict[str, int | float],
|
|
116
138
|
column_names: tuple[int, int],
|
|
117
139
|
name_style: str = ColorPalette.BLUE.value,
|
|
118
140
|
value_style: str = ColorPalette.TEAL.value,
|
|
119
|
-
title:
|
|
141
|
+
title: str | None = None,
|
|
120
142
|
**kwargs,
|
|
121
143
|
) -> Table:
|
|
122
144
|
table = Table(title=title, **kwargs)
|
|
@@ -132,11 +154,12 @@ def create_rich_histogram_table(
|
|
|
132
154
|
|
|
133
155
|
|
|
134
156
|
def display_sample_record(
|
|
135
|
-
record:
|
|
157
|
+
record: dict | pd.Series | pd.DataFrame,
|
|
136
158
|
config_builder: DataDesignerConfigBuilder,
|
|
137
|
-
|
|
159
|
+
processor_data_to_display: dict[str, list[str] | str] | None = None,
|
|
160
|
+
background_color: str | None = None,
|
|
138
161
|
syntax_highlighting_theme: str = "dracula",
|
|
139
|
-
record_index:
|
|
162
|
+
record_index: int | None = None,
|
|
140
163
|
hide_seed_columns: bool = False,
|
|
141
164
|
):
|
|
142
165
|
if isinstance(record, (dict, pd.Series)):
|
|
@@ -171,6 +194,7 @@ def display_sample_record(
|
|
|
171
194
|
+ config_builder.get_columns_of_type(DataDesignerColumnType.EXPRESSION)
|
|
172
195
|
+ config_builder.get_columns_of_type(DataDesignerColumnType.LLM_TEXT)
|
|
173
196
|
+ config_builder.get_columns_of_type(DataDesignerColumnType.LLM_STRUCTURED)
|
|
197
|
+
+ config_builder.get_columns_of_type(DataDesignerColumnType.EMBEDDING)
|
|
174
198
|
)
|
|
175
199
|
if len(non_code_columns) > 0:
|
|
176
200
|
table = Table(title="Generated Columns", **table_kws)
|
|
@@ -178,6 +202,10 @@ def display_sample_record(
|
|
|
178
202
|
table.add_column("Value")
|
|
179
203
|
for col in non_code_columns:
|
|
180
204
|
if not col.drop:
|
|
205
|
+
if col.column_type == DataDesignerColumnType.EMBEDDING:
|
|
206
|
+
record[col.name]["embeddings"] = [
|
|
207
|
+
get_truncated_list_as_string(embd) for embd in record[col.name].get("embeddings")
|
|
208
|
+
]
|
|
181
209
|
table.add_row(col.name, convert_to_row_element(record[col.name]))
|
|
182
210
|
render_list.append(pad_console_element(table))
|
|
183
211
|
|
|
@@ -230,6 +258,15 @@ def display_sample_record(
|
|
|
230
258
|
table.add_row(*row)
|
|
231
259
|
render_list.append(pad_console_element(table, (1, 0, 1, 0)))
|
|
232
260
|
|
|
261
|
+
if processor_data_to_display and len(processor_data_to_display) > 0:
|
|
262
|
+
for processor_name, processor_data in processor_data_to_display.items():
|
|
263
|
+
table = Table(title=f"Processor Outputs: {processor_name}", **table_kws)
|
|
264
|
+
table.add_column("Name")
|
|
265
|
+
table.add_column("Value")
|
|
266
|
+
for col, value in processor_data.items():
|
|
267
|
+
table.add_row(col, convert_to_row_element(value))
|
|
268
|
+
render_list.append(pad_console_element(table, (1, 0, 1, 0)))
|
|
269
|
+
|
|
233
270
|
if record_index is not None:
|
|
234
271
|
index_label = Text(f"[index: {record_index}]", justify="center")
|
|
235
272
|
render_list.append(index_label)
|
|
@@ -237,9 +274,19 @@ def display_sample_record(
|
|
|
237
274
|
console.print(Group(*render_list), markup=False)
|
|
238
275
|
|
|
239
276
|
|
|
277
|
+
def get_truncated_list_as_string(long_list: list[Any], max_items: int = 2) -> str:
|
|
278
|
+
if max_items <= 0:
|
|
279
|
+
raise ValueError("max_items must be greater than 0")
|
|
280
|
+
if len(long_list) > max_items:
|
|
281
|
+
truncated_part = long_list[:max_items]
|
|
282
|
+
return f"[{', '.join(str(x) for x in truncated_part)}, ...]"
|
|
283
|
+
else:
|
|
284
|
+
return str(long_list)
|
|
285
|
+
|
|
286
|
+
|
|
240
287
|
def display_sampler_table(
|
|
241
288
|
sampler_params: dict[SamplerType, ConfigBase],
|
|
242
|
-
title:
|
|
289
|
+
title: str | None = None,
|
|
243
290
|
) -> None:
|
|
244
291
|
table = Table(expand=True)
|
|
245
292
|
table.add_column("Type")
|
|
@@ -274,15 +321,15 @@ def display_model_configs_table(model_configs: list[ModelConfig]) -> None:
|
|
|
274
321
|
table_model_configs.add_column("Alias")
|
|
275
322
|
table_model_configs.add_column("Model")
|
|
276
323
|
table_model_configs.add_column("Provider")
|
|
277
|
-
table_model_configs.add_column("
|
|
278
|
-
table_model_configs.add_column("Top P")
|
|
324
|
+
table_model_configs.add_column("Inference Parameters")
|
|
279
325
|
for model_config in model_configs:
|
|
326
|
+
params_display = model_config.inference_parameters.format_for_display()
|
|
327
|
+
|
|
280
328
|
table_model_configs.add_row(
|
|
281
329
|
model_config.alias,
|
|
282
330
|
model_config.model,
|
|
283
331
|
model_config.provider,
|
|
284
|
-
|
|
285
|
-
str(model_config.inference_parameters.top_p),
|
|
332
|
+
params_display,
|
|
286
333
|
)
|
|
287
334
|
group_args: list = [Rule(title="Model Configs"), table_model_configs]
|
|
288
335
|
if len(model_configs) == 0:
|