data-designer 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/README.md +15 -1
  3. data_designer/cli/commands/download.py +56 -0
  4. data_designer/cli/commands/list.py +4 -18
  5. data_designer/cli/controllers/__init__.py +2 -1
  6. data_designer/cli/controllers/download_controller.py +217 -0
  7. data_designer/cli/controllers/model_controller.py +4 -3
  8. data_designer/cli/forms/field.py +65 -19
  9. data_designer/cli/forms/model_builder.py +251 -44
  10. data_designer/cli/main.py +11 -1
  11. data_designer/cli/repositories/persona_repository.py +88 -0
  12. data_designer/cli/services/__init__.py +2 -1
  13. data_designer/cli/services/download_service.py +97 -0
  14. data_designer/cli/ui.py +131 -0
  15. data_designer/cli/utils.py +34 -0
  16. data_designer/config/analysis/__init__.py +2 -0
  17. data_designer/config/analysis/column_profilers.py +75 -7
  18. data_designer/config/analysis/column_statistics.py +192 -48
  19. data_designer/config/analysis/dataset_profiler.py +23 -5
  20. data_designer/config/analysis/utils/reporting.py +3 -3
  21. data_designer/config/base.py +3 -3
  22. data_designer/config/column_configs.py +27 -6
  23. data_designer/config/column_types.py +24 -17
  24. data_designer/config/config_builder.py +34 -26
  25. data_designer/config/data_designer_config.py +7 -7
  26. data_designer/config/datastore.py +6 -6
  27. data_designer/config/default_model_settings.py +27 -34
  28. data_designer/config/exports.py +14 -1
  29. data_designer/config/models.py +155 -29
  30. data_designer/config/preview_results.py +5 -4
  31. data_designer/config/processors.py +109 -4
  32. data_designer/config/sampler_constraints.py +1 -2
  33. data_designer/config/sampler_params.py +31 -31
  34. data_designer/config/seed.py +1 -2
  35. data_designer/config/utils/code_lang.py +4 -5
  36. data_designer/config/utils/constants.py +31 -8
  37. data_designer/config/utils/io_helpers.py +5 -5
  38. data_designer/config/utils/misc.py +1 -4
  39. data_designer/config/utils/numerical_helpers.py +2 -2
  40. data_designer/config/utils/type_helpers.py +3 -3
  41. data_designer/config/utils/validation.py +39 -9
  42. data_designer/config/utils/visualization.py +62 -15
  43. data_designer/config/validator_params.py +4 -8
  44. data_designer/engine/analysis/column_profilers/base.py +0 -7
  45. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
  46. data_designer/engine/analysis/column_statistics.py +16 -16
  47. data_designer/engine/analysis/dataset_profiler.py +25 -4
  48. data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
  49. data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
  50. data_designer/engine/column_generators/generators/base.py +34 -0
  51. data_designer/engine/column_generators/generators/embedding.py +45 -0
  52. data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
  53. data_designer/engine/column_generators/registry.py +4 -2
  54. data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
  55. data_designer/engine/configurable_task.py +2 -2
  56. data_designer/engine/dataset_builders/artifact_storage.py +14 -5
  57. data_designer/engine/dataset_builders/column_wise_builder.py +12 -8
  58. data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
  59. data_designer/engine/models/facade.py +66 -9
  60. data_designer/engine/models/litellm_overrides.py +5 -6
  61. data_designer/engine/models/parsers/errors.py +2 -4
  62. data_designer/engine/models/parsers/parser.py +2 -3
  63. data_designer/engine/models/parsers/postprocessors.py +3 -4
  64. data_designer/engine/models/parsers/types.py +4 -4
  65. data_designer/engine/models/registry.py +20 -11
  66. data_designer/engine/models/usage.py +7 -9
  67. data_designer/engine/processing/ginja/ast.py +1 -2
  68. data_designer/engine/processing/processors/drop_columns.py +1 -1
  69. data_designer/engine/processing/processors/registry.py +3 -0
  70. data_designer/engine/processing/processors/schema_transform.py +53 -0
  71. data_designer/engine/processing/utils.py +40 -2
  72. data_designer/engine/registry/base.py +12 -12
  73. data_designer/engine/sampling_gen/constraints.py +1 -2
  74. data_designer/engine/sampling_gen/data_sources/base.py +14 -14
  75. data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
  76. data_designer/engine/sampling_gen/people_gen.py +3 -7
  77. data_designer/engine/validators/base.py +2 -2
  78. data_designer/interface/data_designer.py +12 -0
  79. data_designer/interface/results.py +36 -0
  80. data_designer/logging.py +2 -2
  81. data_designer/plugin_manager.py +3 -3
  82. data_designer/plugins/plugin.py +3 -3
  83. data_designer/plugins/registry.py +2 -2
  84. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/METADATA +9 -9
  85. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/RECORD +88 -81
  86. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
  87. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
  88. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
The diff hunks below cover a subset of the changed files; file paths are inferred from the hunk contents and the file list above.

data_designer/config/sampler_params.py

@@ -2,7 +2,7 @@
  # SPDX-License-Identifier: Apache-2.0

  from enum import Enum
- from typing import Literal, Optional, Union
+ from typing import Literal

  import pandas as pd
  from pydantic import Field, field_validator, model_validator

@@ -54,12 +54,12 @@ class CategorySamplerParams(ConfigBase):
      Larger weights result in higher sampling probability for the corresponding value.
      """

-     values: list[Union[str, int, float]] = Field(
+     values: list[str | int | float] = Field(
          ...,
          min_length=1,
          description="List of possible categorical values that can be sampled from.",
      )
-     weights: Optional[list[float]] = Field(
+     weights: list[float] | None = Field(
          default=None,
          description=(
              "List of unnormalized probability weights to assigned to each value, in order. "

@@ -134,7 +134,7 @@ class SubcategorySamplerParams(ConfigBase):
      """

      category: str = Field(..., description="Name of parent category to this subcategory.")
-     values: dict[str, list[Union[str, int, float]]] = Field(
+     values: dict[str, list[str | int | float]] = Field(
          ...,
          description="Mapping from each value of parent category to a list of subcategory values.",
      )

@@ -214,7 +214,7 @@ class UUIDSamplerParams(ConfigBase):
      lowercase UUIDs.
      """

-     prefix: Optional[str] = Field(default=None, description="String prepended to the front of the UUID.")
+     prefix: str | None = Field(default=None, description="String prepended to the front of the UUID.")
      short_form: bool = Field(
          default=False,
          description="If true, all UUIDs sampled will be truncated at 8 characters.",

@@ -259,7 +259,7 @@ class ScipySamplerParams(ConfigBase):
          ...,
          description="Parameters of the scipy.stats distribution given in `dist_name`.",
      )
-     decimal_places: Optional[int] = Field(
+     decimal_places: int | None = Field(
          default=None, description="Number of decimal places to round the sampled values to."
      )
      sampler_type: Literal[SamplerType.SCIPY] = SamplerType.SCIPY

@@ -356,7 +356,7 @@ class GaussianSamplerParams(ConfigBase):

      mean: float = Field(..., description="Mean of the Gaussian distribution")
      stddev: float = Field(..., description="Standard deviation of the Gaussian distribution")
-     decimal_places: Optional[int] = Field(
+     decimal_places: int | None = Field(
          default=None, description="Number of decimal places to round the sampled values to."
      )
      sampler_type: Literal[SamplerType.GAUSSIAN] = SamplerType.GAUSSIAN

@@ -398,7 +398,7 @@ class UniformSamplerParams(ConfigBase):

      low: float = Field(..., description="Lower bound of the uniform distribution, inclusive.")
      high: float = Field(..., description="Upper bound of the uniform distribution, inclusive.")
-     decimal_places: Optional[int] = Field(
+     decimal_places: int | None = Field(
          default=None, description="Number of decimal places to round the sampled values to."
      )
      sampler_type: Literal[SamplerType.UNIFORM] = SamplerType.UNIFORM

@@ -421,8 +421,8 @@ class PersonSamplerParams(ConfigBase):

      Attributes:
          locale: Locale string determining the language and geographic region for synthetic people.
-             Format: language_COUNTRY (e.g., "en_US", "en_GB", "fr_FR", "de_DE", "es_ES", "ja_JP").
-             Defaults to "en_US".
+             Must be a locale supported by a managed Nemotron Personas dataset. The dataset must
+             be downloaded and available in the managed assets directory.
          sex: If specified, filters to only sample people of the specified sex. Options: "Male" or
              "Female". If None, samples both sexes.
          city: If specified, filters to only sample people from the specified city or cities. Can be

@@ -447,11 +447,11 @@
              f"{', '.join(LOCALES_WITH_MANAGED_DATASETS)}."
          ),
      )
-     sex: Optional[SexT] = Field(
+     sex: SexT | None = Field(
          default=None,
          description="If specified, then only synthetic people of the specified sex will be sampled.",
      )
-     city: Optional[Union[str, list[str]]] = Field(
+     city: str | list[str] | None = Field(
          default=None,
          description="If specified, then only synthetic people from these cities will be sampled.",
      )

@@ -461,7 +461,7 @@
          min_length=2,
          max_length=2,
      )
-     select_field_values: Optional[dict[str, list[str]]] = Field(
+     select_field_values: dict[str, list[str]] | None = Field(
          default=None,
          description=(
              "Sample synthetic people with the specified field values. This is meant to be a flexible argument for "

@@ -529,11 +529,11 @@ class PersonFromFakerSamplerParams(ConfigBase):
              "that a synthetic person will be sampled from. E.g, en_US, en_GB, fr_FR, ..."
          ),
      )
-     sex: Optional[SexT] = Field(
+     sex: SexT | None = Field(
          default=None,
          description="If specified, then only synthetic people of the specified sex will be sampled.",
      )
-     city: Optional[Union[str, list[str]]] = Field(
+     city: str | list[str] | None = Field(
          default=None,
          description="If specified, then only synthetic people from these cities will be sampled.",
      )

@@ -585,22 +585,22 @@
          return value


- SamplerParamsT: TypeAlias = Union[
-     SubcategorySamplerParams,
-     CategorySamplerParams,
-     DatetimeSamplerParams,
-     PersonSamplerParams,
-     PersonFromFakerSamplerParams,
-     TimeDeltaSamplerParams,
-     UUIDSamplerParams,
-     BernoulliSamplerParams,
-     BernoulliMixtureSamplerParams,
-     BinomialSamplerParams,
-     GaussianSamplerParams,
-     PoissonSamplerParams,
-     UniformSamplerParams,
-     ScipySamplerParams,
- ]
+ SamplerParamsT: TypeAlias = (
+     SubcategorySamplerParams
+     | CategorySamplerParams
+     | DatetimeSamplerParams
+     | PersonSamplerParams
+     | PersonFromFakerSamplerParams
+     | TimeDeltaSamplerParams
+     | UUIDSamplerParams
+     | BernoulliSamplerParams
+     | BernoulliMixtureSamplerParams
+     | BinomialSamplerParams
+     | GaussianSamplerParams
+     | PoissonSamplerParams
+     | UniformSamplerParams
+     | ScipySamplerParams
+ )


  def is_numerical_sampler_type(sampler_type: SamplerType) -> bool:
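These hunks replace typing.Union/Optional with PEP 604 unions throughout, suggesting the package now targets Python 3.10+. A minimal usage sketch of the updated CategorySamplerParams fields (field names and types come from the diff; the constructor call is an assumed pydantic-style usage):

    from data_designer.config.sampler_params import CategorySamplerParams

    # Hypothetical example: values may mix str/int/float; weights are optional
    # unnormalized probabilities aligned with values, per the field descriptions.
    params = CategorySamplerParams(
        values=["red", "green", "blue"],
        weights=[2.0, 1.0, 1.0],  # sampled with probabilities 0.5, 0.25, 0.25
    )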
data_designer/config/seed.py

@@ -3,7 +3,6 @@

  from abc import ABC
  from enum import Enum
- from typing import Optional, Union

  from pydantic import Field, field_validator, model_validator
  from typing_extensions import Self

@@ -112,7 +111,7 @@ class SeedConfig(ConfigBase):

      dataset: str
      sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED
-     selection_strategy: Optional[Union[IndexRange, PartitionBlock]] = None
+     selection_strategy: IndexRange | PartitionBlock | None = None


  class SeedDatasetReference(ABC, ConfigBase):
data_designer/config/utils/code_lang.py

@@ -4,7 +4,6 @@
  from __future__ import annotations

  from enum import Enum
- from typing import Union


  class CodeLang(str, Enum):

@@ -26,17 +25,17 @@ class CodeLang(str, Enum):
      SQL_ANSI = "sql:ansi"

      @staticmethod
-     def parse(value: Union[str, CodeLang]) -> tuple[str, Union[str, None]]:
+     def parse(value: str | CodeLang) -> tuple[str, str | None]:
          value = value.value if isinstance(value, CodeLang) else value
          split_vals = value.split(":")
          return (split_vals[0], split_vals[1] if len(split_vals) > 1 else None)

      @staticmethod
-     def parse_lang(value: Union[str, CodeLang]) -> str:
+     def parse_lang(value: str | CodeLang) -> str:
          return CodeLang.parse(value)[0]

      @staticmethod
-     def parse_dialect(value: Union[str, CodeLang]) -> Union[str, None]:
+     def parse_dialect(value: str | CodeLang) -> str | None:
          return CodeLang.parse(value)[1]

      @staticmethod

@@ -58,7 +57,7 @@ SQL_DIALECTS: set[CodeLang] = {
  ##########################################################


- def code_lang_to_syntax_lexer(code_lang: Union[CodeLang, str]) -> str:
+ def code_lang_to_syntax_lexer(code_lang: CodeLang | str) -> str:
      """Convert the code language to a syntax lexer for Pygments.

      Reference: https://pygments.org/docs/lexers/
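The parse helpers above split on ":" to separate language from dialect, so their behavior follows directly from the implementation shown:

    from data_designer.config.utils.code_lang import CodeLang

    CodeLang.parse("sql:ansi")         # ("sql", "ansi")
    CodeLang.parse(CodeLang.SQL_ANSI)  # ("sql", "ansi") - enum values unwrap first
    CodeLang.parse_lang("sql:ansi")    # "sql"
    CodeLang.parse_dialect("python")   # None - no ":" means no dialect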
data_designer/config/utils/constants.py

@@ -97,8 +97,6 @@ DEFAULT_AGE_RANGE = [18, 114]
  MIN_AGE = 0
  MAX_AGE = 114

- LOCALES_WITH_MANAGED_DATASETS = ["en_US", "ja_JP", "en_IN", "hi_IN"]
-
  US_STATES_AND_MAJOR_TERRITORIES = {
      # States
      "AK",

@@ -299,15 +297,40 @@ PREDEFINED_PROVIDERS = [
      },
  ]

+
+ DEFAULT_TEXT_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
+ DEFAULT_REASONING_INFERENCE_PARAMS = {"temperature": 0.35, "top_p": 0.95}
+ DEFAULT_VISION_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
+ DEFAULT_EMBEDDING_INFERENCE_PARAMS = {"encoding_format": "float"}
+
+
  PREDEFINED_PROVIDERS_MODEL_MAP = {
      NVIDIA_PROVIDER_NAME: {
-         "text": "nvidia/nvidia-nemotron-nano-9b-v2",
-         "reasoning": "openai/gpt-oss-20b",
-         "vision": "nvidia/nemotron-nano-12b-v2-vl",
+         "text": {"model": "nvidia/nemotron-3-nano-30b-a3b", "inference_parameters": {"temperature": 1.0, "top_p": 1.0}},
+         "reasoning": {"model": "openai/gpt-oss-20b", "inference_parameters": DEFAULT_REASONING_INFERENCE_PARAMS},
+         "vision": {"model": "nvidia/nemotron-nano-12b-v2-vl", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
+         "embedding": {
+             "model": "nvidia/llama-3.2-nv-embedqa-1b-v2",
+             "inference_parameters": DEFAULT_EMBEDDING_INFERENCE_PARAMS | {"extra_body": {"input_type": "query"}},
+         },
      },
      OPENAI_PROVIDER_NAME: {
-         "text": "gpt-4.1",
-         "reasoning": "gpt-5",
-         "vision": "gpt-5",
+         "text": {"model": "gpt-4.1", "inference_parameters": DEFAULT_TEXT_INFERENCE_PARAMS},
+         "reasoning": {"model": "gpt-5", "inference_parameters": DEFAULT_REASONING_INFERENCE_PARAMS},
+         "vision": {"model": "gpt-5", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
+         "embedding": {"model": "text-embedding-3-large", "inference_parameters": DEFAULT_EMBEDDING_INFERENCE_PARAMS},
      },
  }
+
+ # Persona locale metadata - used by the CLI and the person sampler.
+ NEMOTRON_PERSONAS_DATASET_SIZES = {
+     "en_US": "1.24 GB",
+     "en_IN": "2.39 GB",
+     "hi_Deva_IN": "4.14 GB",
+     "hi_Latn_IN": "2.7 GB",
+     "ja_JP": "1.69 GB",
+ }
+
+ LOCALES_WITH_MANAGED_DATASETS = list[str](NEMOTRON_PERSONAS_DATASET_SIZES.keys())
+
+ NEMOTRON_PERSONAS_DATASET_PREFIX = "nemotron-personas-dataset-"
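Each provider entry now carries a per-modality model plus its inference parameters, so consumers index two levels deep; the embedding entry also layers extra_body on top of the shared defaults via a PEP 584 dict merge (`|`). A lookup sketch (keys and values from the diff; importing NVIDIA_PROVIDER_NAME from this module is an assumption):

    from data_designer.config.utils.constants import (
        NVIDIA_PROVIDER_NAME,
        PREDEFINED_PROVIDERS_MODEL_MAP,
    )

    entry = PREDEFINED_PROVIDERS_MODEL_MAP[NVIDIA_PROVIDER_NAME]["embedding"]
    entry["model"]                 # "nvidia/llama-3.2-nv-embedqa-1b-v2"
    entry["inference_parameters"]  # {"encoding_format": "float", "extra_body": {"input_type": "query"}}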
data_designer/config/utils/io_helpers.py

@@ -8,7 +8,7 @@ from datetime import date, datetime, timedelta
  from decimal import Decimal
  from numbers import Number
  from pathlib import Path
- from typing import Any, Union
+ from typing import Any

  import numpy as np
  import pandas as pd

@@ -128,7 +128,7 @@ def write_seed_dataset(dataframe: pd.DataFrame, file_path: Path) -> None:
      dataframe.to_json(file_path, orient="records", lines=True)


- def validate_dataset_file_path(file_path: Union[str, Path], should_exist: bool = True) -> Path:
+ def validate_dataset_file_path(file_path: str | Path, should_exist: bool = True) -> Path:
      """Validate that a dataset file path has a valid extension and optionally exists.

      Args:

@@ -165,7 +165,7 @@ def validate_path_contains_files_of_type(path: str | Path, file_extension: str)
          raise InvalidFilePathError(f"🛑 Path {path!r} does not contain files of type {file_extension!r}.")


- def smart_load_dataframe(dataframe: Union[str, Path, pd.DataFrame]) -> pd.DataFrame:
+ def smart_load_dataframe(dataframe: str | Path | pd.DataFrame) -> pd.DataFrame:
      """Load a dataframe from file if a path is given, otherwise return the dataframe.

      Args:

@@ -197,7 +197,7 @@ def smart_load_dataframe(dataframe: Union[str, Path, pd.DataFrame]) -> pd.DataFrame:
      raise ValueError(f"Unsupported file format: {dataframe}")


- def smart_load_yaml(yaml_in: Union[str, Path, dict]) -> dict:
+ def smart_load_yaml(yaml_in: str | Path | dict) -> dict:
      """Return the yaml config as a dict given flexible input types.

      Args:

@@ -227,7 +227,7 @@ def smart_load_yaml(yaml_in: Union[str, Path, dict]) -> dict:
      return yaml_out


- def serialize_data(data: Union[dict, list, str, Number], **kwargs) -> str:
+ def serialize_data(data: dict | list | str | Number, **kwargs) -> str:
      if isinstance(data, dict):
          return json.dumps(data, ensure_ascii=False, default=_convert_to_serializable, **kwargs)
      elif isinstance(data, list):
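Per their docstrings, the smart-load helpers accept either in-memory objects or paths. A hedged sketch (JSON-lines support is assumed from the write_seed_dataset helper above; exact supported formats are not shown in this diff):

    from pathlib import Path
    import pandas as pd
    from data_designer.config.utils.io_helpers import smart_load_dataframe, smart_load_yaml

    df = smart_load_dataframe(pd.DataFrame({"a": [1, 2]}))  # dataframe returned as-is
    df = smart_load_dataframe(Path("seed.jsonl"))           # loaded from disk (format assumed)
    cfg = smart_load_yaml({"columns": []})                  # dict input passes through as a dict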
data_designer/config/utils/misc.py

@@ -5,7 +5,6 @@ from __future__ import annotations

  import json
  from contextlib import contextmanager
- from typing import Optional, Union

  from jinja2 import TemplateSyntaxError, meta
  from jinja2.sandbox import ImmutableSandboxedEnvironment

@@ -58,9 +57,7 @@ def get_prompt_template_keywords(template: str) -> set[str]:
      return keywords


- def json_indent_list_of_strings(
-     column_names: list[str], *, indent: Optional[Union[int, str]] = None
- ) -> Optional[Union[list[str], str]]:
+ def json_indent_list_of_strings(column_names: list[str], *, indent: int | str | None = None) -> list[str] | str | None:
      """Convert a list of column names to a JSON string if the list is long.

      This function helps keep Data Designer's __repr__ output clean and readable.
data_designer/config/utils/numerical_helpers.py

@@ -3,7 +3,7 @@

  import numbers
  from numbers import Number
- from typing import Any, Type
+ from typing import Any

  from data_designer.config.utils.constants import REPORTING_PRECISION


@@ -18,7 +18,7 @@ def is_float(val: Any) -> bool:

  def prepare_number_for_reporting(
      value: Number,
-     target_type: Type[Number],
+     target_type: type[Number],
      precision: int = REPORTING_PRECISION,
  ) -> Number:
      """Ensure native python types and round to `precision` decimal digits."""
data_designer/config/utils/type_helpers.py

@@ -3,7 +3,7 @@

  import inspect
  from enum import Enum
- from typing import Any, Literal, Type, get_args, get_origin
+ from typing import Any, Literal, get_args, get_origin

  from pydantic import BaseModel


@@ -56,7 +56,7 @@ def create_str_enum_from_discriminated_type_union(
      return StrEnum(enum_name, {v.replace("-", "_").upper(): v for v in set(discriminator_field_values)})


- def get_sampler_params() -> dict[str, Type[BaseModel]]:
+ def get_sampler_params() -> dict[str, type[BaseModel]]:
      """Returns a dictionary of sampler parameter classes."""
      params_cls_list = [
          params_cls

@@ -83,7 +83,7 @@ def get_sampler_params() -> dict[str, Type[BaseModel]]:
      return params_cls_dict


- def resolve_string_enum(enum_instance: Any, enum_type: Type[Enum]) -> Enum:
+ def resolve_string_enum(enum_instance: Any, enum_type: type[Enum]) -> Enum:
      if not issubclass(enum_type, Enum):
          raise InvalidEnumValueError(f"🛑 `enum_type` must be a subclass of Enum. You provided: {enum_type}")
      invalid_enum_value_error = InvalidEnumValueError(
data_designer/config/utils/validation.py

@@ -5,7 +5,6 @@ from __future__ import annotations

  from enum import Enum
  from string import Formatter
- from typing import Optional

  from jinja2 import meta
  from jinja2.sandbox import ImmutableSandboxedEnvironment

@@ -15,10 +14,13 @@ from rich.console import Console, Group
  from rich.padding import Padding
  from rich.panel import Panel

- from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_llm_generated
- from data_designer.config.processors import ProcessorConfig, ProcessorType
+ from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_model_generated
+ from data_designer.config.processors import ProcessorConfigT, ProcessorType
  from data_designer.config.utils.constants import RICH_CONSOLE_THEME
- from data_designer.config.utils.misc import can_run_data_designer_locally
+ from data_designer.config.utils.misc import (
+     can_run_data_designer_locally,
+     get_prompt_template_keywords,
+ )
  from data_designer.config.validator_params import ValidatorType


@@ -42,7 +44,7 @@ class ViolationLevel(str, Enum):


  class Violation(BaseModel):
-     column: Optional[str] = None
+     column: str | None = None
      type: ViolationType
      message: str
      level: ViolationLevel

@@ -54,7 +56,7 @@ class Violation(BaseModel):


  def validate_data_designer_config(
      columns: list[ColumnConfigT],
-     processor_configs: list[ProcessorConfig],
+     processor_configs: list[ProcessorConfigT],
      allowed_references: list[str],
  ) -> list[Violation]:
      violations = []

@@ -63,6 +65,7 @@ def validate_data_designer_config(
      violations.extend(validate_expression_references(columns=columns, allowed_references=allowed_references))
      violations.extend(validate_columns_not_all_dropped(columns=columns))
      violations.extend(validate_drop_columns_processor(columns=columns, processor_configs=processor_configs))
+     violations.extend(validate_schema_transform_processor(columns=columns, processor_configs=processor_configs))
      if not can_run_data_designer_locally():
          violations.extend(validate_local_only_columns(columns=columns))
      return violations

@@ -115,7 +118,7 @@ def validate_prompt_templates(
  ) -> list[Violation]:
      env = ImmutableSandboxedEnvironment()

-     columns_with_prompts = [c for c in columns if column_type_is_llm_generated(c.column_type)]
+     columns_with_prompts = [c for c in columns if column_type_is_model_generated(c.column_type)]

      violations = []
      for column in columns_with_prompts:

@@ -269,9 +272,9 @@ def validate_columns_not_all_dropped(


  def validate_drop_columns_processor(
      columns: list[ColumnConfigT],
-     processor_configs: list[ProcessorConfig],
+     processor_configs: list[ProcessorConfigT],
  ) -> list[Violation]:
-     all_column_names = set([c.name for c in columns])
+     all_column_names = {c.name for c in columns}
      for processor_config in processor_configs:
          if processor_config.processor_type == ProcessorType.DROP_COLUMNS:
              invalid_columns = set(processor_config.column_names) - all_column_names

@@ -288,6 +291,33 @@
      return []


+ def validate_schema_transform_processor(
+     columns: list[ColumnConfigT],
+     processor_configs: list[ProcessorConfigT],
+ ) -> list[Violation]:
+     violations = []
+
+     all_column_names = {c.name for c in columns}
+     for processor_config in processor_configs:
+         if processor_config.processor_type == ProcessorType.SCHEMA_TRANSFORM:
+             for col, template in processor_config.template.items():
+                 template_keywords = get_prompt_template_keywords(template)
+                 invalid_keywords = set(template_keywords) - all_column_names
+                 if len(invalid_keywords) > 0:
+                     invalid_keywords = ", ".join([f"'{k}'" for k in invalid_keywords])
+                     message = f"Ancillary dataset processor attempts to reference columns {invalid_keywords} in the template for '{col}', but the columns are not defined in the dataset."
+                     violations.append(
+                         Violation(
+                             column=None,
+                             type=ViolationType.INVALID_REFERENCE,
+                             message=message,
+                             level=ViolationLevel.ERROR,
+                         )
+                     )
+
+     return violations
+
+
  def validate_expression_references(
      columns: list[ColumnConfigT],
      allowed_references: list[str],
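The new validator relies on get_prompt_template_keywords to pull referenced names out of each schema-transform template. A self-contained sketch of that check using jinja2 directly (the helper's internals are assumed to work roughly like this, given the jinja2 imports above):

    from jinja2 import meta
    from jinja2.sandbox import ImmutableSandboxedEnvironment

    env = ImmutableSandboxedEnvironment()

    def template_keywords(template: str) -> set[str]:
        # Undeclared variables are the column names the template references.
        return meta.find_undeclared_variables(env.parse(template))

    columns = {"first_name", "last_name"}
    missing = template_keywords("{{ first_name }} {{ surname }}") - columns
    # missing == {"surname"} -> reported as an INVALID_REFERENCE violation at ERROR level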
data_designer/config/utils/visualization.py

@@ -8,7 +8,7 @@ import os
  from collections import OrderedDict
  from enum import Enum
  from functools import cached_property
- from typing import TYPE_CHECKING, Optional, Union
+ from typing import TYPE_CHECKING, Any

  import numpy as np
  import pandas as pd

@@ -36,11 +36,11 @@ if TYPE_CHECKING:
  console = Console()


- def get_nvidia_api_key() -> Optional[str]:
+ def get_nvidia_api_key() -> str | None:
      return os.getenv(NVIDIA_API_KEY_ENV_VAR_NAME)


- def get_openai_api_key() -> Optional[str]:
+ def get_openai_api_key() -> str | None:
      return os.getenv(OPENAI_API_KEY_ENV_VAR_NAME)


@@ -72,13 +72,17 @@ class WithRecordSamplerMixin:
          else:
              raise DatasetSampleDisplayError("No valid dataset found in results object.")

+     def _has_processor_artifacts(self) -> bool:
+         return hasattr(self, "processor_artifacts") and self.processor_artifacts is not None
+
      def display_sample_record(
          self,
-         index: Optional[int] = None,
+         index: int | None = None,
          *,
          hide_seed_columns: bool = False,
          syntax_highlighting_theme: str = "dracula",
-         background_color: Optional[str] = None,
+         background_color: str | None = None,
+         processors_to_display: list[str] | None = None,
      ) -> None:
          """Display a sample record from the Data Designer dataset preview.

@@ -90,6 +94,7 @@
                  documentation from `rich` for information about available themes.
              background_color: Background color to use for the record. See the `Syntax`
                  documentation from `rich` for information about available background colors.
+             processors_to_display: List of processors to display the artifacts for. If None, all processors will be displayed.
          """
          i = index or self._display_cycle_index

@@ -99,8 +104,25 @@
          except IndexError:
              raise DatasetSampleDisplayError(f"Index {i} is out of bounds for dataset of length {num_records}.")

+         processor_data_to_display = None
+         if self._has_processor_artifacts() and len(self.processor_artifacts) > 0:
+             if processors_to_display is None:
+                 processors_to_display = list(self.processor_artifacts.keys())
+
+             if len(processors_to_display) > 0:
+                 processor_data_to_display = {}
+                 for processor in processors_to_display:
+                     if (
+                         isinstance(self.processor_artifacts[processor], list)
+                         and len(self.processor_artifacts[processor]) == num_records
+                     ):
+                         processor_data_to_display[processor] = self.processor_artifacts[processor][i]
+                     else:
+                         processor_data_to_display[processor] = self.processor_artifacts[processor]
+
          display_sample_record(
              record=record,
+             processor_data_to_display=processor_data_to_display,
              config_builder=self._config_builder,
              background_color=background_color,
              syntax_highlighting_theme=syntax_highlighting_theme,
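With the additions above, callers can scope which processor artifacts are rendered alongside the record. A hypothetical call (the preview results object and processor name are placeholders):

    # None (the default) displays artifacts for all processors.
    preview.display_sample_record(index=0, processors_to_display=["schema_transform"])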
@@ -112,11 +134,11 @@


  def create_rich_histogram_table(
-     data: dict[str, Union[int, float]],
+     data: dict[str, int | float],
      column_names: tuple[int, int],
      name_style: str = ColorPalette.BLUE.value,
      value_style: str = ColorPalette.TEAL.value,
-     title: Optional[str] = None,
+     title: str | None = None,
      **kwargs,
  ) -> Table:
      table = Table(title=title, **kwargs)

@@ -132,11 +154,12 @@


  def display_sample_record(
-     record: Union[dict, pd.Series, pd.DataFrame],
+     record: dict | pd.Series | pd.DataFrame,
      config_builder: DataDesignerConfigBuilder,
-     background_color: Optional[str] = None,
+     processor_data_to_display: dict[str, list[str] | str] | None = None,
+     background_color: str | None = None,
      syntax_highlighting_theme: str = "dracula",
-     record_index: Optional[int] = None,
+     record_index: int | None = None,
      hide_seed_columns: bool = False,
  ):
      if isinstance(record, (dict, pd.Series)):

@@ -171,6 +194,7 @@
          + config_builder.get_columns_of_type(DataDesignerColumnType.EXPRESSION)
          + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_TEXT)
          + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_STRUCTURED)
+         + config_builder.get_columns_of_type(DataDesignerColumnType.EMBEDDING)
      )
      if len(non_code_columns) > 0:
          table = Table(title="Generated Columns", **table_kws)

@@ -178,6 +202,10 @@
          table.add_column("Value")
          for col in non_code_columns:
              if not col.drop:
+                 if col.column_type == DataDesignerColumnType.EMBEDDING:
+                     record[col.name]["embeddings"] = [
+                         get_truncated_list_as_string(embd) for embd in record[col.name].get("embeddings")
+                     ]
                  table.add_row(col.name, convert_to_row_element(record[col.name]))
          render_list.append(pad_console_element(table))

@@ -230,6 +258,15 @@
          table.add_row(*row)
      render_list.append(pad_console_element(table, (1, 0, 1, 0)))

+     if processor_data_to_display and len(processor_data_to_display) > 0:
+         for processor_name, processor_data in processor_data_to_display.items():
+             table = Table(title=f"Processor Outputs: {processor_name}", **table_kws)
+             table.add_column("Name")
+             table.add_column("Value")
+             for col, value in processor_data.items():
+                 table.add_row(col, convert_to_row_element(value))
+             render_list.append(pad_console_element(table, (1, 0, 1, 0)))
+
      if record_index is not None:
          index_label = Text(f"[index: {record_index}]", justify="center")
          render_list.append(index_label)

@@ -237,9 +274,19 @@
      console.print(Group(*render_list), markup=False)


+ def get_truncated_list_as_string(long_list: list[Any], max_items: int = 2) -> str:
+     if max_items <= 0:
+         raise ValueError("max_items must be greater than 0")
+     if len(long_list) > max_items:
+         truncated_part = long_list[:max_items]
+         return f"[{', '.join(str(x) for x in truncated_part)}, ...]"
+     else:
+         return str(long_list)
+
+
  def display_sampler_table(
      sampler_params: dict[SamplerType, ConfigBase],
-     title: Optional[str] = None,
+     title: str | None = None,
  ) -> None:
      table = Table(expand=True)
      table.add_column("Type")

@@ -274,15 +321,15 @@ def display_model_configs_table(model_configs: list[ModelConfig]) -> None:
      table_model_configs.add_column("Alias")
      table_model_configs.add_column("Model")
      table_model_configs.add_column("Provider")
-     table_model_configs.add_column("Temperature")
-     table_model_configs.add_column("Top P")
+     table_model_configs.add_column("Inference Parameters")
      for model_config in model_configs:
+         params_display = model_config.inference_parameters.format_for_display()
+
          table_model_configs.add_row(
              model_config.alias,
              model_config.model,
              model_config.provider,
-             str(model_config.inference_parameters.temperature),
-             str(model_config.inference_parameters.top_p),
+             params_display,
          )
      group_args: list = [Rule(title="Model Configs"), table_model_configs]
      if len(model_configs) == 0:
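The truncation helper added above behaves as follows (directly from the implementation shown):

    get_truncated_list_as_string([0.12, 0.34, 0.56, 0.78])  # "[0.12, 0.34, ...]"
    get_truncated_list_as_string([1, 2])                    # "[1, 2]" - at or under max_items
    get_truncated_list_as_string([1], max_items=0)          # raises ValueError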