data-designer 0.1.5__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/cli/README.md +15 -1
- data_designer/cli/commands/download.py +56 -0
- data_designer/cli/commands/list.py +4 -18
- data_designer/cli/controllers/__init__.py +2 -1
- data_designer/cli/controllers/download_controller.py +217 -0
- data_designer/cli/controllers/model_controller.py +4 -3
- data_designer/cli/forms/field.py +65 -19
- data_designer/cli/forms/model_builder.py +251 -44
- data_designer/cli/main.py +11 -1
- data_designer/cli/repositories/persona_repository.py +88 -0
- data_designer/cli/services/__init__.py +2 -1
- data_designer/cli/services/download_service.py +97 -0
- data_designer/cli/ui.py +131 -0
- data_designer/cli/utils.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +75 -7
- data_designer/config/analysis/column_statistics.py +192 -48
- data_designer/config/analysis/dataset_profiler.py +23 -5
- data_designer/config/analysis/utils/reporting.py +3 -3
- data_designer/config/base.py +3 -3
- data_designer/config/column_configs.py +27 -6
- data_designer/config/column_types.py +24 -17
- data_designer/config/config_builder.py +34 -26
- data_designer/config/data_designer_config.py +7 -7
- data_designer/config/datastore.py +6 -6
- data_designer/config/default_model_settings.py +27 -34
- data_designer/config/exports.py +8 -0
- data_designer/config/models.py +155 -29
- data_designer/config/preview_results.py +6 -8
- data_designer/config/processors.py +63 -2
- data_designer/config/sampler_constraints.py +1 -2
- data_designer/config/sampler_params.py +31 -31
- data_designer/config/seed.py +1 -2
- data_designer/config/utils/code_lang.py +4 -5
- data_designer/config/utils/constants.py +31 -8
- data_designer/config/utils/io_helpers.py +5 -5
- data_designer/config/utils/misc.py +1 -4
- data_designer/config/utils/numerical_helpers.py +2 -2
- data_designer/config/utils/type_helpers.py +3 -3
- data_designer/config/utils/validation.py +7 -8
- data_designer/config/utils/visualization.py +32 -17
- data_designer/config/validator_params.py +4 -8
- data_designer/engine/analysis/column_profilers/base.py +0 -7
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
- data_designer/engine/analysis/column_statistics.py +16 -16
- data_designer/engine/analysis/dataset_profiler.py +25 -4
- data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
- data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
- data_designer/engine/column_generators/generators/base.py +34 -0
- data_designer/engine/column_generators/generators/embedding.py +45 -0
- data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
- data_designer/engine/column_generators/registry.py +4 -2
- data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
- data_designer/engine/configurable_task.py +2 -2
- data_designer/engine/dataset_builders/artifact_storage.py +1 -2
- data_designer/engine/dataset_builders/column_wise_builder.py +11 -10
- data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
- data_designer/engine/models/facade.py +66 -9
- data_designer/engine/models/litellm_overrides.py +5 -6
- data_designer/engine/models/parsers/errors.py +2 -4
- data_designer/engine/models/parsers/parser.py +2 -3
- data_designer/engine/models/parsers/postprocessors.py +3 -4
- data_designer/engine/models/parsers/types.py +4 -4
- data_designer/engine/models/registry.py +20 -11
- data_designer/engine/models/usage.py +7 -9
- data_designer/engine/processing/ginja/ast.py +1 -2
- data_designer/engine/processing/utils.py +40 -2
- data_designer/engine/registry/base.py +12 -12
- data_designer/engine/sampling_gen/constraints.py +1 -2
- data_designer/engine/sampling_gen/data_sources/base.py +14 -14
- data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
- data_designer/engine/sampling_gen/people_gen.py +3 -7
- data_designer/engine/validators/base.py +2 -2
- data_designer/logging.py +2 -2
- data_designer/plugin_manager.py +3 -3
- data_designer/plugins/plugin.py +3 -3
- data_designer/plugins/registry.py +2 -2
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/METADATA +1 -1
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/RECORD +83 -77
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -4,7 +4,6 @@
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
6
|
from enum import Enum
|
|
7
|
-
from typing import Union
|
|
8
7
|
|
|
9
8
|
|
|
10
9
|
class CodeLang(str, Enum):
|
|
@@ -26,17 +25,17 @@ class CodeLang(str, Enum):
|
|
|
26
25
|
SQL_ANSI = "sql:ansi"
|
|
27
26
|
|
|
28
27
|
@staticmethod
|
|
29
|
-
def parse(value:
|
|
28
|
+
def parse(value: str | CodeLang) -> tuple[str, str | None]:
    """Split a code-language identifier into its language and dialect parts.

    Args:
        value: A ``CodeLang`` member or a raw string such as ``"sql:ansi"``.

    Returns:
        A ``(language, dialect)`` tuple. ``dialect`` is ``None`` when the
        identifier contains no ``":"`` separator.
    """
    raw = value.value if isinstance(value, CodeLang) else value
    # Only the first two ":"-separated pieces are used; later pieces are ignored.
    lang, *rest = raw.split(":")
    return (lang, rest[0] if rest else None)
|
|
33
32
|
|
|
34
33
|
@staticmethod
|
|
35
|
-
def parse_lang(value:
|
|
34
|
+
def parse_lang(value: str | CodeLang) -> str:
    """Return the language part of a code-language identifier.

    Args:
        value: A ``CodeLang`` member or a raw identifier string.

    Returns:
        The first element of the tuple produced by ``CodeLang.parse``.
    """
    lang, _dialect = CodeLang.parse(value)
    return lang
|
|
37
36
|
|
|
38
37
|
@staticmethod
|
|
39
|
-
def parse_dialect(value:
|
|
38
|
+
def parse_dialect(value: str | CodeLang) -> str | None:
    """Return the dialect part of a code-language identifier, if any.

    Args:
        value: A ``CodeLang`` member or a raw identifier string.

    Returns:
        The second element of the tuple produced by ``CodeLang.parse``,
        which is ``None`` when the identifier has no dialect.
    """
    _lang, dialect = CodeLang.parse(value)
    return dialect
|
|
41
40
|
|
|
42
41
|
@staticmethod
|
|
@@ -58,7 +57,7 @@ SQL_DIALECTS: set[CodeLang] = {
|
|
|
58
57
|
##########################################################
|
|
59
58
|
|
|
60
59
|
|
|
61
|
-
def code_lang_to_syntax_lexer(code_lang:
|
|
60
|
+
def code_lang_to_syntax_lexer(code_lang: CodeLang | str) -> str:
|
|
62
61
|
"""Convert the code language to a syntax lexer for Pygments.
|
|
63
62
|
|
|
64
63
|
Reference: https://pygments.org/docs/lexers/
|
|
@@ -97,8 +97,6 @@ DEFAULT_AGE_RANGE = [18, 114]
|
|
|
97
97
|
MIN_AGE = 0
|
|
98
98
|
MAX_AGE = 114
|
|
99
99
|
|
|
100
|
-
LOCALES_WITH_MANAGED_DATASETS = ["en_US", "ja_JP", "en_IN", "hi_IN"]
|
|
101
|
-
|
|
102
100
|
US_STATES_AND_MAJOR_TERRITORIES = {
|
|
103
101
|
# States
|
|
104
102
|
"AK",
|
|
@@ -299,15 +297,40 @@ PREDEFINED_PROVIDERS = [
|
|
|
299
297
|
},
|
|
300
298
|
]
|
|
301
299
|
|
|
300
|
+
|
|
301
|
+
DEFAULT_TEXT_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
|
|
302
|
+
DEFAULT_REASONING_INFERENCE_PARAMS = {"temperature": 0.35, "top_p": 0.95}
|
|
303
|
+
DEFAULT_VISION_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
|
|
304
|
+
DEFAULT_EMBEDDING_INFERENCE_PARAMS = {"encoding_format": "float"}
|
|
305
|
+
|
|
306
|
+
|
|
302
307
|
PREDEFINED_PROVIDERS_MODEL_MAP = {
|
|
303
308
|
NVIDIA_PROVIDER_NAME: {
|
|
304
|
-
"text": "nvidia/
|
|
305
|
-
"reasoning": "openai/gpt-oss-20b",
|
|
306
|
-
"vision": "nvidia/nemotron-nano-12b-v2-vl",
|
|
309
|
+
"text": {"model": "nvidia/nemotron-3-nano-30b-a3b", "inference_parameters": {"temperature": 1.0, "top_p": 1.0}},
|
|
310
|
+
"reasoning": {"model": "openai/gpt-oss-20b", "inference_parameters": DEFAULT_REASONING_INFERENCE_PARAMS},
|
|
311
|
+
"vision": {"model": "nvidia/nemotron-nano-12b-v2-vl", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
|
|
312
|
+
"embedding": {
|
|
313
|
+
"model": "nvidia/llama-3.2-nv-embedqa-1b-v2",
|
|
314
|
+
"inference_parameters": DEFAULT_EMBEDDING_INFERENCE_PARAMS | {"extra_body": {"input_type": "query"}},
|
|
315
|
+
},
|
|
307
316
|
},
|
|
308
317
|
OPENAI_PROVIDER_NAME: {
|
|
309
|
-
"text": "gpt-4.1",
|
|
310
|
-
"reasoning": "gpt-5",
|
|
311
|
-
"vision": "gpt-5",
|
|
318
|
+
"text": {"model": "gpt-4.1", "inference_parameters": DEFAULT_TEXT_INFERENCE_PARAMS},
|
|
319
|
+
"reasoning": {"model": "gpt-5", "inference_parameters": DEFAULT_REASONING_INFERENCE_PARAMS},
|
|
320
|
+
"vision": {"model": "gpt-5", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
|
|
321
|
+
"embedding": {"model": "text-embedding-3-large", "inference_parameters": DEFAULT_EMBEDDING_INFERENCE_PARAMS},
|
|
312
322
|
},
|
|
313
323
|
}
|
|
324
|
+
|
|
325
|
+
# Persona locale metadata - used by the CLI and the person sampler.
# Maps each managed locale code to a human-readable dataset size string.
NEMOTRON_PERSONAS_DATASET_SIZES = {
    "en_US": "1.24 GB",
    "en_IN": "2.39 GB",
    "hi_Deva_IN": "4.14 GB",
    "hi_Latn_IN": "2.7 GB",
    "ja_JP": "1.69 GB",
}

# Locale codes, in the insertion order of NEMOTRON_PERSONAS_DATASET_SIZES.
# The element type is declared with an annotation rather than by calling the
# parameterized alias `list[str](...)` as a constructor.
LOCALES_WITH_MANAGED_DATASETS: list[str] = list(NEMOTRON_PERSONAS_DATASET_SIZES)

NEMOTRON_PERSONAS_DATASET_PREFIX = "nemotron-personas-dataset-"
|
|
@@ -8,7 +8,7 @@ from datetime import date, datetime, timedelta
|
|
|
8
8
|
from decimal import Decimal
|
|
9
9
|
from numbers import Number
|
|
10
10
|
from pathlib import Path
|
|
11
|
-
from typing import Any
|
|
11
|
+
from typing import Any
|
|
12
12
|
|
|
13
13
|
import numpy as np
|
|
14
14
|
import pandas as pd
|
|
@@ -128,7 +128,7 @@ def write_seed_dataset(dataframe: pd.DataFrame, file_path: Path) -> None:
|
|
|
128
128
|
dataframe.to_json(file_path, orient="records", lines=True)
|
|
129
129
|
|
|
130
130
|
|
|
131
|
-
def validate_dataset_file_path(file_path:
|
|
131
|
+
def validate_dataset_file_path(file_path: str | Path, should_exist: bool = True) -> Path:
|
|
132
132
|
"""Validate that a dataset file path has a valid extension and optionally exists.
|
|
133
133
|
|
|
134
134
|
Args:
|
|
@@ -165,7 +165,7 @@ def validate_path_contains_files_of_type(path: str | Path, file_extension: str)
|
|
|
165
165
|
raise InvalidFilePathError(f"🛑 Path {path!r} does not contain files of type {file_extension!r}.")
|
|
166
166
|
|
|
167
167
|
|
|
168
|
-
def smart_load_dataframe(dataframe:
|
|
168
|
+
def smart_load_dataframe(dataframe: str | Path | pd.DataFrame) -> pd.DataFrame:
|
|
169
169
|
"""Load a dataframe from file if a path is given, otherwise return the dataframe.
|
|
170
170
|
|
|
171
171
|
Args:
|
|
@@ -197,7 +197,7 @@ def smart_load_dataframe(dataframe: Union[str, Path, pd.DataFrame]) -> pd.DataFr
|
|
|
197
197
|
raise ValueError(f"Unsupported file format: {dataframe}")
|
|
198
198
|
|
|
199
199
|
|
|
200
|
-
def smart_load_yaml(yaml_in:
|
|
200
|
+
def smart_load_yaml(yaml_in: str | Path | dict) -> dict:
|
|
201
201
|
"""Return the yaml config as a dict given flexible input types.
|
|
202
202
|
|
|
203
203
|
Args:
|
|
@@ -227,7 +227,7 @@ def smart_load_yaml(yaml_in: Union[str, Path, dict]) -> dict:
|
|
|
227
227
|
return yaml_out
|
|
228
228
|
|
|
229
229
|
|
|
230
|
-
def serialize_data(data:
|
|
230
|
+
def serialize_data(data: dict | list | str | Number, **kwargs) -> str:
|
|
231
231
|
if isinstance(data, dict):
|
|
232
232
|
return json.dumps(data, ensure_ascii=False, default=_convert_to_serializable, **kwargs)
|
|
233
233
|
elif isinstance(data, list):
|
|
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
import json
|
|
7
7
|
from contextlib import contextmanager
|
|
8
|
-
from typing import Optional, Union
|
|
9
8
|
|
|
10
9
|
from jinja2 import TemplateSyntaxError, meta
|
|
11
10
|
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
|
@@ -58,9 +57,7 @@ def get_prompt_template_keywords(template: str) -> set[str]:
|
|
|
58
57
|
return keywords
|
|
59
58
|
|
|
60
59
|
|
|
61
|
-
def json_indent_list_of_strings(
|
|
62
|
-
column_names: list[str], *, indent: Optional[Union[int, str]] = None
|
|
63
|
-
) -> Optional[Union[list[str], str]]:
|
|
60
|
+
def json_indent_list_of_strings(column_names: list[str], *, indent: int | str | None = None) -> list[str] | str | None:
|
|
64
61
|
"""Convert a list of column names to a JSON string if the list is long.
|
|
65
62
|
|
|
66
63
|
This function helps keep Data Designer's __repr__ output clean and readable.
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
|
|
4
4
|
import numbers
|
|
5
5
|
from numbers import Number
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
from data_designer.config.utils.constants import REPORTING_PRECISION
|
|
9
9
|
|
|
@@ -18,7 +18,7 @@ def is_float(val: Any) -> bool:
|
|
|
18
18
|
|
|
19
19
|
def prepare_number_for_reporting(
|
|
20
20
|
value: Number,
|
|
21
|
-
target_type:
|
|
21
|
+
target_type: type[Number],
|
|
22
22
|
precision: int = REPORTING_PRECISION,
|
|
23
23
|
) -> Number:
|
|
24
24
|
"""Ensure native python types and round to `precision` decimal digits."""
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
|
|
4
4
|
import inspect
|
|
5
5
|
from enum import Enum
|
|
6
|
-
from typing import Any, Literal,
|
|
6
|
+
from typing import Any, Literal, get_args, get_origin
|
|
7
7
|
|
|
8
8
|
from pydantic import BaseModel
|
|
9
9
|
|
|
@@ -56,7 +56,7 @@ def create_str_enum_from_discriminated_type_union(
|
|
|
56
56
|
return StrEnum(enum_name, {v.replace("-", "_").upper(): v for v in set(discriminator_field_values)})
|
|
57
57
|
|
|
58
58
|
|
|
59
|
-
def get_sampler_params() -> dict[str,
|
|
59
|
+
def get_sampler_params() -> dict[str, type[BaseModel]]:
|
|
60
60
|
"""Returns a dictionary of sampler parameter classes."""
|
|
61
61
|
params_cls_list = [
|
|
62
62
|
params_cls
|
|
@@ -83,7 +83,7 @@ def get_sampler_params() -> dict[str, Type[BaseModel]]:
|
|
|
83
83
|
return params_cls_dict
|
|
84
84
|
|
|
85
85
|
|
|
86
|
-
def resolve_string_enum(enum_instance: Any, enum_type:
|
|
86
|
+
def resolve_string_enum(enum_instance: Any, enum_type: type[Enum]) -> Enum:
|
|
87
87
|
if not issubclass(enum_type, Enum):
|
|
88
88
|
raise InvalidEnumValueError(f"🛑 `enum_type` must be a subclass of Enum. You provided: {enum_type}")
|
|
89
89
|
invalid_enum_value_error = InvalidEnumValueError(
|
|
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
from enum import Enum
|
|
7
7
|
from string import Formatter
|
|
8
|
-
from typing import Optional
|
|
9
8
|
|
|
10
9
|
from jinja2 import meta
|
|
11
10
|
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
|
@@ -15,8 +14,8 @@ from rich.console import Console, Group
|
|
|
15
14
|
from rich.padding import Padding
|
|
16
15
|
from rich.panel import Panel
|
|
17
16
|
|
|
18
|
-
from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType,
|
|
19
|
-
from data_designer.config.processors import
|
|
17
|
+
from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_model_generated
|
|
18
|
+
from data_designer.config.processors import ProcessorConfigT, ProcessorType
|
|
20
19
|
from data_designer.config.utils.constants import RICH_CONSOLE_THEME
|
|
21
20
|
from data_designer.config.utils.misc import (
|
|
22
21
|
can_run_data_designer_locally,
|
|
@@ -45,7 +44,7 @@ class ViolationLevel(str, Enum):
|
|
|
45
44
|
|
|
46
45
|
|
|
47
46
|
class Violation(BaseModel):
|
|
48
|
-
column:
|
|
47
|
+
column: str | None = None
|
|
49
48
|
type: ViolationType
|
|
50
49
|
message: str
|
|
51
50
|
level: ViolationLevel
|
|
@@ -57,7 +56,7 @@ class Violation(BaseModel):
|
|
|
57
56
|
|
|
58
57
|
def validate_data_designer_config(
|
|
59
58
|
columns: list[ColumnConfigT],
|
|
60
|
-
processor_configs: list[
|
|
59
|
+
processor_configs: list[ProcessorConfigT],
|
|
61
60
|
allowed_references: list[str],
|
|
62
61
|
) -> list[Violation]:
|
|
63
62
|
violations = []
|
|
@@ -119,7 +118,7 @@ def validate_prompt_templates(
|
|
|
119
118
|
) -> list[Violation]:
|
|
120
119
|
env = ImmutableSandboxedEnvironment()
|
|
121
120
|
|
|
122
|
-
columns_with_prompts = [c for c in columns if
|
|
121
|
+
columns_with_prompts = [c for c in columns if column_type_is_model_generated(c.column_type)]
|
|
123
122
|
|
|
124
123
|
violations = []
|
|
125
124
|
for column in columns_with_prompts:
|
|
@@ -273,7 +272,7 @@ def validate_columns_not_all_dropped(
|
|
|
273
272
|
|
|
274
273
|
def validate_drop_columns_processor(
|
|
275
274
|
columns: list[ColumnConfigT],
|
|
276
|
-
processor_configs: list[
|
|
275
|
+
processor_configs: list[ProcessorConfigT],
|
|
277
276
|
) -> list[Violation]:
|
|
278
277
|
all_column_names = {c.name for c in columns}
|
|
279
278
|
for processor_config in processor_configs:
|
|
@@ -294,7 +293,7 @@ def validate_drop_columns_processor(
|
|
|
294
293
|
|
|
295
294
|
def validate_schema_transform_processor(
|
|
296
295
|
columns: list[ColumnConfigT],
|
|
297
|
-
processor_configs: list[
|
|
296
|
+
processor_configs: list[ProcessorConfigT],
|
|
298
297
|
) -> list[Violation]:
|
|
299
298
|
violations = []
|
|
300
299
|
|
|
@@ -8,7 +8,7 @@ import os
|
|
|
8
8
|
from collections import OrderedDict
|
|
9
9
|
from enum import Enum
|
|
10
10
|
from functools import cached_property
|
|
11
|
-
from typing import TYPE_CHECKING,
|
|
11
|
+
from typing import TYPE_CHECKING, Any
|
|
12
12
|
|
|
13
13
|
import numpy as np
|
|
14
14
|
import pandas as pd
|
|
@@ -36,11 +36,11 @@ if TYPE_CHECKING:
|
|
|
36
36
|
console = Console()
|
|
37
37
|
|
|
38
38
|
|
|
39
|
-
def get_nvidia_api_key() ->
|
|
39
|
+
def get_nvidia_api_key() -> str | None:
    """Read the environment variable named by ``NVIDIA_API_KEY_ENV_VAR_NAME``.

    Returns:
        The variable's value, or ``None`` when it is not set.
    """
    return os.environ.get(NVIDIA_API_KEY_ENV_VAR_NAME)
|
|
41
41
|
|
|
42
42
|
|
|
43
|
-
def get_openai_api_key() ->
|
|
43
|
+
def get_openai_api_key() -> str | None:
    """Read the environment variable named by ``OPENAI_API_KEY_ENV_VAR_NAME``.

    Returns:
        The variable's value, or ``None`` when it is not set.
    """
    return os.environ.get(OPENAI_API_KEY_ENV_VAR_NAME)
|
|
45
45
|
|
|
46
46
|
|
|
@@ -77,12 +77,12 @@ class WithRecordSamplerMixin:
|
|
|
77
77
|
|
|
78
78
|
def display_sample_record(
|
|
79
79
|
self,
|
|
80
|
-
index:
|
|
80
|
+
index: int | None = None,
|
|
81
81
|
*,
|
|
82
82
|
hide_seed_columns: bool = False,
|
|
83
83
|
syntax_highlighting_theme: str = "dracula",
|
|
84
|
-
background_color:
|
|
85
|
-
processors_to_display:
|
|
84
|
+
background_color: str | None = None,
|
|
85
|
+
processors_to_display: list[str] | None = None,
|
|
86
86
|
) -> None:
|
|
87
87
|
"""Display a sample record from the Data Designer dataset preview.
|
|
88
88
|
|
|
@@ -134,11 +134,11 @@ class WithRecordSamplerMixin:
|
|
|
134
134
|
|
|
135
135
|
|
|
136
136
|
def create_rich_histogram_table(
|
|
137
|
-
data: dict[str,
|
|
137
|
+
data: dict[str, int | float],
|
|
138
138
|
column_names: tuple[int, int],
|
|
139
139
|
name_style: str = ColorPalette.BLUE.value,
|
|
140
140
|
value_style: str = ColorPalette.TEAL.value,
|
|
141
|
-
title:
|
|
141
|
+
title: str | None = None,
|
|
142
142
|
**kwargs,
|
|
143
143
|
) -> Table:
|
|
144
144
|
table = Table(title=title, **kwargs)
|
|
@@ -154,12 +154,12 @@ def create_rich_histogram_table(
|
|
|
154
154
|
|
|
155
155
|
|
|
156
156
|
def display_sample_record(
|
|
157
|
-
record:
|
|
157
|
+
record: dict | pd.Series | pd.DataFrame,
|
|
158
158
|
config_builder: DataDesignerConfigBuilder,
|
|
159
|
-
processor_data_to_display:
|
|
160
|
-
background_color:
|
|
159
|
+
processor_data_to_display: dict[str, list[str] | str] | None = None,
|
|
160
|
+
background_color: str | None = None,
|
|
161
161
|
syntax_highlighting_theme: str = "dracula",
|
|
162
|
-
record_index:
|
|
162
|
+
record_index: int | None = None,
|
|
163
163
|
hide_seed_columns: bool = False,
|
|
164
164
|
):
|
|
165
165
|
if isinstance(record, (dict, pd.Series)):
|
|
@@ -194,6 +194,7 @@ def display_sample_record(
|
|
|
194
194
|
+ config_builder.get_columns_of_type(DataDesignerColumnType.EXPRESSION)
|
|
195
195
|
+ config_builder.get_columns_of_type(DataDesignerColumnType.LLM_TEXT)
|
|
196
196
|
+ config_builder.get_columns_of_type(DataDesignerColumnType.LLM_STRUCTURED)
|
|
197
|
+
+ config_builder.get_columns_of_type(DataDesignerColumnType.EMBEDDING)
|
|
197
198
|
)
|
|
198
199
|
if len(non_code_columns) > 0:
|
|
199
200
|
table = Table(title="Generated Columns", **table_kws)
|
|
@@ -201,6 +202,10 @@ def display_sample_record(
|
|
|
201
202
|
table.add_column("Value")
|
|
202
203
|
for col in non_code_columns:
|
|
203
204
|
if not col.drop:
|
|
205
|
+
if col.column_type == DataDesignerColumnType.EMBEDDING:
|
|
206
|
+
record[col.name]["embeddings"] = [
|
|
207
|
+
get_truncated_list_as_string(embd) for embd in record[col.name].get("embeddings")
|
|
208
|
+
]
|
|
204
209
|
table.add_row(col.name, convert_to_row_element(record[col.name]))
|
|
205
210
|
render_list.append(pad_console_element(table))
|
|
206
211
|
|
|
@@ -269,9 +274,19 @@ def display_sample_record(
|
|
|
269
274
|
console.print(Group(*render_list), markup=False)
|
|
270
275
|
|
|
271
276
|
|
|
277
|
+
def get_truncated_list_as_string(long_list: list[Any], max_items: int = 2) -> str:
    """Render a sequence as a bracketed string showing at most ``max_items`` items.

    Every item is formatted with ``str()`` whether or not the sequence is
    truncated. Previously the untruncated branch returned ``str(long_list)``,
    which formats items with ``repr()`` and uses the container's own layout
    for non-list sequences, so the two branches could disagree.

    Args:
        long_list: The sequence to render; it must support ``len()`` and slicing.
        max_items: Maximum number of leading items to show. Must be positive.

    Returns:
        ``"[a, b, ...]"`` when the sequence has more than ``max_items`` items,
        otherwise ``"[a, b]"`` with every item shown.

    Raises:
        ValueError: If ``max_items`` is not greater than 0.
    """
    if max_items <= 0:
        raise ValueError("max_items must be greater than 0")
    shown = [str(item) for item in long_list[:max_items]]
    # The ellipsis marks that items beyond `max_items` were dropped.
    if len(long_list) > max_items:
        shown.append("...")
    return f"[{', '.join(shown)}]"
|
|
285
|
+
|
|
286
|
+
|
|
272
287
|
def display_sampler_table(
|
|
273
288
|
sampler_params: dict[SamplerType, ConfigBase],
|
|
274
|
-
title:
|
|
289
|
+
title: str | None = None,
|
|
275
290
|
) -> None:
|
|
276
291
|
table = Table(expand=True)
|
|
277
292
|
table.add_column("Type")
|
|
@@ -306,15 +321,15 @@ def display_model_configs_table(model_configs: list[ModelConfig]) -> None:
|
|
|
306
321
|
table_model_configs.add_column("Alias")
|
|
307
322
|
table_model_configs.add_column("Model")
|
|
308
323
|
table_model_configs.add_column("Provider")
|
|
309
|
-
table_model_configs.add_column("
|
|
310
|
-
table_model_configs.add_column("Top P")
|
|
324
|
+
table_model_configs.add_column("Inference Parameters")
|
|
311
325
|
for model_config in model_configs:
|
|
326
|
+
params_display = model_config.inference_parameters.format_for_display()
|
|
327
|
+
|
|
312
328
|
table_model_configs.add_row(
|
|
313
329
|
model_config.alias,
|
|
314
330
|
model_config.model,
|
|
315
331
|
model_config.provider,
|
|
316
|
-
|
|
317
|
-
str(model_config.inference_parameters.top_p),
|
|
332
|
+
params_display,
|
|
318
333
|
)
|
|
319
334
|
group_args: list = [Rule(title="Model Configs"), table_model_configs]
|
|
320
335
|
if len(model_configs) == 0:
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
4
|
from enum import Enum
|
|
5
|
-
from typing import Any
|
|
5
|
+
from typing import Any
|
|
6
6
|
|
|
7
7
|
from pydantic import Field, field_serializer, model_validator
|
|
8
8
|
from typing_extensions import Self, TypeAlias
|
|
@@ -51,7 +51,7 @@ class LocalCallableValidatorParams(ConfigBase):
|
|
|
51
51
|
validation_function: Any = Field(
|
|
52
52
|
description="Function (Callable[[pd.DataFrame], pd.DataFrame]) to validate the data"
|
|
53
53
|
)
|
|
54
|
-
output_schema:
|
|
54
|
+
output_schema: dict[str, Any] | None = Field(
|
|
55
55
|
default=None, description="Expected schema for local callable validator's output"
|
|
56
56
|
)
|
|
57
57
|
|
|
@@ -80,7 +80,7 @@ class RemoteValidatorParams(ConfigBase):
|
|
|
80
80
|
"""
|
|
81
81
|
|
|
82
82
|
endpoint_url: str = Field(description="URL of the remote endpoint")
|
|
83
|
-
output_schema:
|
|
83
|
+
output_schema: dict[str, Any] | None = Field(
|
|
84
84
|
default=None, description="Expected schema for remote validator's output"
|
|
85
85
|
)
|
|
86
86
|
timeout: float = Field(default=30.0, gt=0, description="The timeout for the HTTP request")
|
|
@@ -89,8 +89,4 @@ class RemoteValidatorParams(ConfigBase):
|
|
|
89
89
|
max_parallel_requests: int = Field(default=4, ge=1, description="The maximum number of parallel requests to make")
|
|
90
90
|
|
|
91
91
|
|
|
92
|
-
ValidatorParamsT: TypeAlias =
|
|
93
|
-
CodeValidatorParams,
|
|
94
|
-
LocalCallableValidatorParams,
|
|
95
|
-
RemoteValidatorParams,
|
|
96
|
-
]
|
|
92
|
+
ValidatorParamsT: TypeAlias = CodeValidatorParams | LocalCallableValidatorParams | RemoteValidatorParams
|
|
@@ -7,7 +7,6 @@ import logging
|
|
|
7
7
|
from abc import ABC, abstractmethod
|
|
8
8
|
|
|
9
9
|
import pandas as pd
|
|
10
|
-
import pyarrow as pa
|
|
11
10
|
from pydantic import BaseModel, model_validator
|
|
12
11
|
from typing_extensions import Self
|
|
13
12
|
|
|
@@ -29,12 +28,6 @@ class ColumnConfigWithDataFrame(ConfigBase):
|
|
|
29
28
|
raise ValueError(f"Column {self.column_config.name!r} not found in DataFrame")
|
|
30
29
|
return self
|
|
31
30
|
|
|
32
|
-
@model_validator(mode="after")
|
|
33
|
-
def ensure_pyarrow_backend(self) -> Self:
|
|
34
|
-
if not all(isinstance(dtype, pd.ArrowDtype) for dtype in self.df.dtypes):
|
|
35
|
-
self.df = pa.Table.from_pandas(self.df).to_pandas(types_mapper=pd.ArrowDtype)
|
|
36
|
-
return self
|
|
37
|
-
|
|
38
31
|
def as_tuple(self) -> tuple[SingleColumnConfig, pd.DataFrame]:
|
|
39
32
|
return (self.column_config, self.df)
|
|
40
33
|
|
|
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
import logging
|
|
7
7
|
import random
|
|
8
|
-
from typing import Union
|
|
9
8
|
|
|
10
9
|
from data_designer.config.analysis.column_profilers import (
|
|
11
10
|
JudgeScoreProfilerConfig,
|
|
@@ -69,7 +68,7 @@ class JudgeScoreProfiler(ColumnProfiler[JudgeScoreProfilerConfig]):
|
|
|
69
68
|
)
|
|
70
69
|
|
|
71
70
|
for score in column_config.scores:
|
|
72
|
-
score_name = score.name
|
|
71
|
+
score_name = score.name
|
|
73
72
|
logger.info(f"{random.choice(['👩⚖️', '👨⚖️'])} Summarizing LLM-as-judge score: '{score_name}'")
|
|
74
73
|
score_sample = sample_scores_and_reasoning(
|
|
75
74
|
scores=score_distributions.scores[score_name],
|
|
@@ -96,7 +95,7 @@ class JudgeScoreProfiler(ColumnProfiler[JudgeScoreProfilerConfig]):
|
|
|
96
95
|
name: str,
|
|
97
96
|
sample: list[JudgeScoreSample],
|
|
98
97
|
histogram: CategoricalHistogramData,
|
|
99
|
-
distribution:
|
|
98
|
+
distribution: CategoricalDistribution | NumericalDistribution | MissingValue,
|
|
100
99
|
distribution_type: ColumnDistributionType,
|
|
101
100
|
) -> JudgeScoreSummary:
|
|
102
101
|
if isinstance(distribution, MissingValue) or not sample:
|
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
6
|
import logging
|
|
7
|
-
from typing import Any,
|
|
7
|
+
from typing import Any, TypeAlias
|
|
8
8
|
|
|
9
9
|
import pandas as pd
|
|
10
10
|
from pydantic import BaseModel
|
|
@@ -41,7 +41,7 @@ class GeneralColumnStatisticsCalculator(BaseModel):
|
|
|
41
41
|
return self.column_config_with_df.df
|
|
42
42
|
|
|
43
43
|
@property
|
|
44
|
-
def column_statistics_type(self) ->
|
|
44
|
+
def column_statistics_type(self) -> type[ColumnStatisticsT]:
|
|
45
45
|
return DEFAULT_COLUMN_STATISTICS_MAP.get(self.column_config.column_type, GeneralColumnStatistics)
|
|
46
46
|
|
|
47
47
|
def calculate(self) -> Self:
|
|
@@ -59,7 +59,7 @@ class GeneralColumnStatisticsCalculator(BaseModel):
|
|
|
59
59
|
)
|
|
60
60
|
|
|
61
61
|
def calculate_general_column_info(self) -> dict[str, Any]:
|
|
62
|
-
return calculate_general_column_info(self.column_config, self.df)
|
|
62
|
+
return calculate_general_column_info(self.column_config.name, self.df)
|
|
63
63
|
|
|
64
64
|
def __repr__(self) -> str:
|
|
65
65
|
params = []
|
|
@@ -93,7 +93,7 @@ class SamplerColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
|
|
|
93
93
|
return (
|
|
94
94
|
{
|
|
95
95
|
"sampler_type": SamplerType(self.column_config.sampler_type),
|
|
96
|
-
**calculate_column_distribution(self.column_config, self.df, dist_type),
|
|
96
|
+
**calculate_column_distribution(self.column_config.name, self.df, dist_type),
|
|
97
97
|
}
|
|
98
98
|
if make_dist
|
|
99
99
|
else {
|
|
@@ -109,23 +109,23 @@ class SeedDatasetColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
|
|
|
109
109
|
|
|
110
110
|
class ValidationColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
|
|
111
111
|
def calculate_validation_column_info(self) -> dict[str, Any]:
|
|
112
|
-
return calculate_validation_column_info(self.column_config, self.df)
|
|
112
|
+
return calculate_validation_column_info(self.column_config.name, self.df)
|
|
113
113
|
|
|
114
114
|
|
|
115
115
|
class ExpressionColumnStatisticsCalculator(GeneralColumnStatisticsCalculator): ...
|
|
116
116
|
|
|
117
117
|
|
|
118
|
-
ColumnStatisticsCalculatorT: TypeAlias =
|
|
119
|
-
ExpressionColumnStatisticsCalculator
|
|
120
|
-
ValidationColumnStatisticsCalculator
|
|
121
|
-
GeneralColumnStatisticsCalculator
|
|
122
|
-
LLMCodeColumnStatisticsCalculator
|
|
123
|
-
LLMJudgedColumnStatisticsCalculator
|
|
124
|
-
LLMStructuredColumnStatisticsCalculator
|
|
125
|
-
LLMTextColumnStatisticsCalculator
|
|
126
|
-
SamplerColumnStatisticsCalculator
|
|
127
|
-
SeedDatasetColumnStatisticsCalculator
|
|
128
|
-
|
|
118
|
+
ColumnStatisticsCalculatorT: TypeAlias = (
|
|
119
|
+
ExpressionColumnStatisticsCalculator
|
|
120
|
+
| ValidationColumnStatisticsCalculator
|
|
121
|
+
| GeneralColumnStatisticsCalculator
|
|
122
|
+
| LLMCodeColumnStatisticsCalculator
|
|
123
|
+
| LLMJudgedColumnStatisticsCalculator
|
|
124
|
+
| LLMStructuredColumnStatisticsCalculator
|
|
125
|
+
| LLMTextColumnStatisticsCalculator
|
|
126
|
+
| SamplerColumnStatisticsCalculator
|
|
127
|
+
| SeedDatasetColumnStatisticsCalculator
|
|
128
|
+
)
|
|
129
129
|
DEFAULT_COLUMN_STATISTICS_CALCULATOR_MAP = {
|
|
130
130
|
DataDesignerColumnType.EXPRESSION: ExpressionColumnStatisticsCalculator,
|
|
131
131
|
DataDesignerColumnType.VALIDATION: ValidationColumnStatisticsCalculator,
|
|
@@ -6,6 +6,7 @@ from collections.abc import Sequence
|
|
|
6
6
|
from functools import cached_property
|
|
7
7
|
|
|
8
8
|
import pandas as pd
|
|
9
|
+
import pyarrow as pa
|
|
9
10
|
from pydantic import Field, field_validator
|
|
10
11
|
|
|
11
12
|
from data_designer.config.analysis.column_profilers import ColumnProfilerConfigT
|
|
@@ -19,10 +20,8 @@ from data_designer.config.column_types import (
|
|
|
19
20
|
from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame, ColumnProfiler
|
|
20
21
|
from data_designer.engine.analysis.column_statistics import get_column_statistics_calculator
|
|
21
22
|
from data_designer.engine.analysis.errors import DatasetProfilerConfigurationError
|
|
22
|
-
from data_designer.engine.
|
|
23
|
-
|
|
24
|
-
MultiColumnConfig,
|
|
25
|
-
)
|
|
23
|
+
from data_designer.engine.analysis.utils.column_statistics_calculations import has_pyarrow_backend
|
|
24
|
+
from data_designer.engine.dataset_builders.multi_column_configs import DatasetBuilderColumnConfigT, MultiColumnConfig
|
|
26
25
|
from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
|
|
27
26
|
from data_designer.engine.resources.resource_provider import ResourceProvider
|
|
28
27
|
|
|
@@ -68,6 +67,7 @@ class DataDesignerDatasetProfiler:
|
|
|
68
67
|
logger.info("📐 Measuring dataset column statistics:")
|
|
69
68
|
|
|
70
69
|
self._validate_schema_consistency(list(dataset.columns))
|
|
70
|
+
dataset = self._convert_to_pyarrow_backend_if_needed(dataset)
|
|
71
71
|
|
|
72
72
|
column_statistics = []
|
|
73
73
|
for c in self.config.column_configs:
|
|
@@ -100,6 +100,27 @@ class DataDesignerDatasetProfiler:
|
|
|
100
100
|
column_profiles=column_profiles if column_profiles else None,
|
|
101
101
|
)
|
|
102
102
|
|
|
103
|
+
def _convert_to_pyarrow_backend_if_needed(self, dataset: pd.DataFrame) -> pd.DataFrame:
    """Return ``dataset`` with PyArrow-backed dtypes when conversion is possible.

    If ``has_pyarrow_backend`` already reports a PyArrow backend, the dataset
    is returned unchanged. Otherwise a round-trip through ``pa.Table`` is
    attempted. A failed conversion is logged as a warning and the original
    dataset is returned, so profiling can continue.

    Args:
        dataset: The dataset to convert.

    Returns:
        The converted dataset, or the original one if it was already
        PyArrow-backed or the conversion failed.
    """
    import re  # local import: only used to tidy the error message on failure

    if has_pyarrow_backend(dataset):
        return dataset
    try:
        return pa.Table.from_pandas(dataset).to_pandas(types_mapper=pd.ArrowDtype)
    except Exception as e:  # deliberate best-effort: never fail profiling over a dtype conversion
        # For ArrowTypeError, the second arg contains the more informative message
        if isinstance(e, pa.lib.ArrowTypeError) and len(e.args) > 1:
            error_msg = str(e.args[1])
        else:
            error_msg = str(e)
        # Make sure column names are clear in the error message. Names are
        # matched as whole tokens, longest first, in a single pass. This
        # avoids quoting a short name inside an unrelated word, re-quoting
        # an already quoted name, and failing on non-string or empty labels.
        column_names = sorted(
            {str(col) for col in dataset.columns if str(col)},
            key=len,
            reverse=True,
        )
        if column_names:
            alternatives = "|".join(re.escape(name) for name in column_names)
            pattern = re.compile(rf"(?<![\w'])(?:{alternatives})(?![\w'])")
            error_msg = pattern.sub(lambda match: f"'{match.group(0)}'", error_msg)
        logger.warning("⚠️ Unable to convert the dataset to a PyArrow backend")
        logger.warning(f" |-- Conversion Error Message: {error_msg}")
        logger.warning(" |-- This is often due to at least one column having mixed data types")
        logger.warning(
            " |-- Note: Reported data types will be inferred from the first non-null value of each column"
        )
    return dataset
|
|
123
|
+
|
|
103
124
|
def _create_column_profiler(self, profiler_config: ColumnProfilerConfigT) -> ColumnProfiler:
|
|
104
125
|
return self.registry.column_profilers.get_for_config_type(type(profiler_config))(
|
|
105
126
|
config=profiler_config, resource_provider=self.resource_provider
|