data-designer 0.1.5__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/README.md +15 -1
  3. data_designer/cli/commands/download.py +56 -0
  4. data_designer/cli/commands/list.py +4 -18
  5. data_designer/cli/controllers/__init__.py +2 -1
  6. data_designer/cli/controllers/download_controller.py +217 -0
  7. data_designer/cli/controllers/model_controller.py +4 -3
  8. data_designer/cli/forms/field.py +65 -19
  9. data_designer/cli/forms/model_builder.py +251 -44
  10. data_designer/cli/main.py +11 -1
  11. data_designer/cli/repositories/persona_repository.py +88 -0
  12. data_designer/cli/services/__init__.py +2 -1
  13. data_designer/cli/services/download_service.py +97 -0
  14. data_designer/cli/ui.py +131 -0
  15. data_designer/cli/utils.py +34 -0
  16. data_designer/config/analysis/__init__.py +2 -0
  17. data_designer/config/analysis/column_profilers.py +75 -7
  18. data_designer/config/analysis/column_statistics.py +192 -48
  19. data_designer/config/analysis/dataset_profiler.py +23 -5
  20. data_designer/config/analysis/utils/reporting.py +3 -3
  21. data_designer/config/base.py +3 -3
  22. data_designer/config/column_configs.py +27 -6
  23. data_designer/config/column_types.py +24 -17
  24. data_designer/config/config_builder.py +34 -26
  25. data_designer/config/data_designer_config.py +7 -7
  26. data_designer/config/datastore.py +6 -6
  27. data_designer/config/default_model_settings.py +27 -34
  28. data_designer/config/exports.py +8 -0
  29. data_designer/config/models.py +155 -29
  30. data_designer/config/preview_results.py +6 -8
  31. data_designer/config/processors.py +63 -2
  32. data_designer/config/sampler_constraints.py +1 -2
  33. data_designer/config/sampler_params.py +31 -31
  34. data_designer/config/seed.py +1 -2
  35. data_designer/config/utils/code_lang.py +4 -5
  36. data_designer/config/utils/constants.py +31 -8
  37. data_designer/config/utils/io_helpers.py +5 -5
  38. data_designer/config/utils/misc.py +1 -4
  39. data_designer/config/utils/numerical_helpers.py +2 -2
  40. data_designer/config/utils/type_helpers.py +3 -3
  41. data_designer/config/utils/validation.py +7 -8
  42. data_designer/config/utils/visualization.py +32 -17
  43. data_designer/config/validator_params.py +4 -8
  44. data_designer/engine/analysis/column_profilers/base.py +0 -7
  45. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
  46. data_designer/engine/analysis/column_statistics.py +16 -16
  47. data_designer/engine/analysis/dataset_profiler.py +25 -4
  48. data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
  49. data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
  50. data_designer/engine/column_generators/generators/base.py +34 -0
  51. data_designer/engine/column_generators/generators/embedding.py +45 -0
  52. data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
  53. data_designer/engine/column_generators/registry.py +4 -2
  54. data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
  55. data_designer/engine/configurable_task.py +2 -2
  56. data_designer/engine/dataset_builders/artifact_storage.py +1 -2
  57. data_designer/engine/dataset_builders/column_wise_builder.py +11 -10
  58. data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
  59. data_designer/engine/models/facade.py +66 -9
  60. data_designer/engine/models/litellm_overrides.py +5 -6
  61. data_designer/engine/models/parsers/errors.py +2 -4
  62. data_designer/engine/models/parsers/parser.py +2 -3
  63. data_designer/engine/models/parsers/postprocessors.py +3 -4
  64. data_designer/engine/models/parsers/types.py +4 -4
  65. data_designer/engine/models/registry.py +20 -11
  66. data_designer/engine/models/usage.py +7 -9
  67. data_designer/engine/processing/ginja/ast.py +1 -2
  68. data_designer/engine/processing/utils.py +40 -2
  69. data_designer/engine/registry/base.py +12 -12
  70. data_designer/engine/sampling_gen/constraints.py +1 -2
  71. data_designer/engine/sampling_gen/data_sources/base.py +14 -14
  72. data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
  73. data_designer/engine/sampling_gen/people_gen.py +3 -7
  74. data_designer/engine/validators/base.py +2 -2
  75. data_designer/logging.py +2 -2
  76. data_designer/plugin_manager.py +3 -3
  77. data_designer/plugins/plugin.py +3 -3
  78. data_designer/plugins/registry.py +2 -2
  79. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/METADATA +1 -1
  80. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/RECORD +83 -77
  81. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
  82. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
  83. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
data_designer/config/utils/code_lang.py
@@ -4,7 +4,6 @@
 from __future__ import annotations

 from enum import Enum
-from typing import Union


 class CodeLang(str, Enum):
@@ -26,17 +25,17 @@ class CodeLang(str, Enum):
     SQL_ANSI = "sql:ansi"

     @staticmethod
-    def parse(value: Union[str, CodeLang]) -> tuple[str, Union[str, None]]:
+    def parse(value: str | CodeLang) -> tuple[str, str | None]:
         value = value.value if isinstance(value, CodeLang) else value
         split_vals = value.split(":")
         return (split_vals[0], split_vals[1] if len(split_vals) > 1 else None)

     @staticmethod
-    def parse_lang(value: Union[str, CodeLang]) -> str:
+    def parse_lang(value: str | CodeLang) -> str:
         return CodeLang.parse(value)[0]

     @staticmethod
-    def parse_dialect(value: Union[str, CodeLang]) -> Union[str, None]:
+    def parse_dialect(value: str | CodeLang) -> str | None:
         return CodeLang.parse(value)[1]

     @staticmethod
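
Only the annotations change in `parse` and friends; the ":"-splitting behavior is untouched. A quick sketch of that behavior, using only what is shown above:

    >>> CodeLang.parse("sql:ansi")
    ('sql', 'ansi')
    >>> CodeLang.parse_lang(CodeLang.SQL_ANSI)
    'sql'
    >>> CodeLang.parse_dialect("sql")  # no dialect suffix, so this returns None
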
@@ -58,7 +57,7 @@ SQL_DIALECTS: set[CodeLang] = {
 ##########################################################


-def code_lang_to_syntax_lexer(code_lang: Union[CodeLang, str]) -> str:
+def code_lang_to_syntax_lexer(code_lang: CodeLang | str) -> str:
     """Convert the code language to a syntax lexer for Pygments.

     Reference: https://pygments.org/docs/lexers/
data_designer/config/utils/constants.py
@@ -97,8 +97,6 @@ DEFAULT_AGE_RANGE = [18, 114]
 MIN_AGE = 0
 MAX_AGE = 114

-LOCALES_WITH_MANAGED_DATASETS = ["en_US", "ja_JP", "en_IN", "hi_IN"]
-
 US_STATES_AND_MAJOR_TERRITORIES = {
     # States
     "AK",
@@ -299,15 +297,40 @@ PREDEFINED_PROVIDERS = [
     },
 ]

+
+DEFAULT_TEXT_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
+DEFAULT_REASONING_INFERENCE_PARAMS = {"temperature": 0.35, "top_p": 0.95}
+DEFAULT_VISION_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
+DEFAULT_EMBEDDING_INFERENCE_PARAMS = {"encoding_format": "float"}
+
+
 PREDEFINED_PROVIDERS_MODEL_MAP = {
     NVIDIA_PROVIDER_NAME: {
-        "text": "nvidia/nvidia-nemotron-nano-9b-v2",
-        "reasoning": "openai/gpt-oss-20b",
-        "vision": "nvidia/nemotron-nano-12b-v2-vl",
+        "text": {"model": "nvidia/nemotron-3-nano-30b-a3b", "inference_parameters": {"temperature": 1.0, "top_p": 1.0}},
+        "reasoning": {"model": "openai/gpt-oss-20b", "inference_parameters": DEFAULT_REASONING_INFERENCE_PARAMS},
+        "vision": {"model": "nvidia/nemotron-nano-12b-v2-vl", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
+        "embedding": {
+            "model": "nvidia/llama-3.2-nv-embedqa-1b-v2",
+            "inference_parameters": DEFAULT_EMBEDDING_INFERENCE_PARAMS | {"extra_body": {"input_type": "query"}},
+        },
     },
     OPENAI_PROVIDER_NAME: {
-        "text": "gpt-4.1",
-        "reasoning": "gpt-5",
-        "vision": "gpt-5",
+        "text": {"model": "gpt-4.1", "inference_parameters": DEFAULT_TEXT_INFERENCE_PARAMS},
+        "reasoning": {"model": "gpt-5", "inference_parameters": DEFAULT_REASONING_INFERENCE_PARAMS},
+        "vision": {"model": "gpt-5", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
+        "embedding": {"model": "text-embedding-3-large", "inference_parameters": DEFAULT_EMBEDDING_INFERENCE_PARAMS},
     },
 }
+
+# Persona locale metadata - used by the CLI and the person sampler.
+NEMOTRON_PERSONAS_DATASET_SIZES = {
+    "en_US": "1.24 GB",
+    "en_IN": "2.39 GB",
+    "hi_Deva_IN": "4.14 GB",
+    "hi_Latn_IN": "2.7 GB",
+    "ja_JP": "1.69 GB",
+}
+
+LOCALES_WITH_MANAGED_DATASETS = list[str](NEMOTRON_PERSONAS_DATASET_SIZES.keys())
+
+NEMOTRON_PERSONAS_DATASET_PREFIX = "nemotron-personas-dataset-"
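
The provider map now pairs each model with explicit inference parameters, and the embedding entry composes its parameters with the dict-merge operator. A minimal sketch of the merge semantics (right-hand operand wins on key collisions), using only the constants defined above:

    merged = DEFAULT_EMBEDDING_INFERENCE_PARAMS | {"extra_body": {"input_type": "query"}}
    # merged == {"encoding_format": "float", "extra_body": {"input_type": "query"}}
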
data_designer/config/utils/io_helpers.py
@@ -8,7 +8,7 @@ from datetime import date, datetime, timedelta
 from decimal import Decimal
 from numbers import Number
 from pathlib import Path
-from typing import Any, Union
+from typing import Any

 import numpy as np
 import pandas as pd
@@ -128,7 +128,7 @@ def write_seed_dataset(dataframe: pd.DataFrame, file_path: Path) -> None:
     dataframe.to_json(file_path, orient="records", lines=True)


-def validate_dataset_file_path(file_path: Union[str, Path], should_exist: bool = True) -> Path:
+def validate_dataset_file_path(file_path: str | Path, should_exist: bool = True) -> Path:
     """Validate that a dataset file path has a valid extension and optionally exists.

     Args:
@@ -165,7 +165,7 @@ def validate_path_contains_files_of_type(path: str | Path, file_extension: str)
         raise InvalidFilePathError(f"🛑 Path {path!r} does not contain files of type {file_extension!r}.")


-def smart_load_dataframe(dataframe: Union[str, Path, pd.DataFrame]) -> pd.DataFrame:
+def smart_load_dataframe(dataframe: str | Path | pd.DataFrame) -> pd.DataFrame:
     """Load a dataframe from file if a path is given, otherwise return the dataframe.

     Args:
@@ -197,7 +197,7 @@ def smart_load_dataframe(dataframe: Union[str, Path, pd.DataFrame]) -> pd.DataFr
197
197
  raise ValueError(f"Unsupported file format: {dataframe}")
198
198
 
199
199
 
200
- def smart_load_yaml(yaml_in: Union[str, Path, dict]) -> dict:
200
+ def smart_load_yaml(yaml_in: str | Path | dict) -> dict:
201
201
  """Return the yaml config as a dict given flexible input types.
202
202
 
203
203
  Args:
@@ -227,7 +227,7 @@ def smart_load_yaml(yaml_in: Union[str, Path, dict]) -> dict:
     return yaml_out


-def serialize_data(data: Union[dict, list, str, Number], **kwargs) -> str:
+def serialize_data(data: dict | list | str | Number, **kwargs) -> str:
     if isinstance(data, dict):
         return json.dumps(data, ensure_ascii=False, default=_convert_to_serializable, **kwargs)
     elif isinstance(data, list):
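
The changes to these utility modules follow one pattern: `typing.Union`/`typing.Optional` annotations are rewritten in PEP 604 union syntax. The two spellings are equivalent for type checkers; at runtime, `X | Y` in annotations requires Python 3.10+ unless `from __future__ import annotations` is in effect, as it is in several of these modules. An illustrative (hypothetical) signature showing the before and after:

    from pathlib import Path
    from typing import Optional, Union

    def load(path: Union[str, Path], default: Optional[dict] = None) -> dict: ...  # before
    def load(path: str | Path, default: dict | None = None) -> dict: ...           # after
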
data_designer/config/utils/misc.py
@@ -5,7 +5,6 @@ from __future__ import annotations

 import json
 from contextlib import contextmanager
-from typing import Optional, Union

 from jinja2 import TemplateSyntaxError, meta
 from jinja2.sandbox import ImmutableSandboxedEnvironment
@@ -58,9 +57,7 @@ def get_prompt_template_keywords(template: str) -> set[str]:
     return keywords


-def json_indent_list_of_strings(
-    column_names: list[str], *, indent: Optional[Union[int, str]] = None
-) -> Optional[Union[list[str], str]]:
+def json_indent_list_of_strings(column_names: list[str], *, indent: int | str | None = None) -> list[str] | str | None:
     """Convert a list of column names to a JSON string if the list is long.

     This function helps keep Data Designer's __repr__ output clean and readable.
data_designer/config/utils/numerical_helpers.py
@@ -3,7 +3,7 @@

 import numbers
 from numbers import Number
-from typing import Any, Type
+from typing import Any

 from data_designer.config.utils.constants import REPORTING_PRECISION

@@ -18,7 +18,7 @@ def is_float(val: Any) -> bool:

 def prepare_number_for_reporting(
     value: Number,
-    target_type: Type[Number],
+    target_type: type[Number],
     precision: int = REPORTING_PRECISION,
 ) -> Number:
     """Ensure native python types and round to `precision` decimal digits."""
data_designer/config/utils/type_helpers.py
@@ -3,7 +3,7 @@

 import inspect
 from enum import Enum
-from typing import Any, Literal, Type, get_args, get_origin
+from typing import Any, Literal, get_args, get_origin

 from pydantic import BaseModel

@@ -56,7 +56,7 @@ def create_str_enum_from_discriminated_type_union(
     return StrEnum(enum_name, {v.replace("-", "_").upper(): v for v in set(discriminator_field_values)})


-def get_sampler_params() -> dict[str, Type[BaseModel]]:
+def get_sampler_params() -> dict[str, type[BaseModel]]:
     """Returns a dictionary of sampler parameter classes."""
     params_cls_list = [
         params_cls
@@ -83,7 +83,7 @@ def get_sampler_params() -> dict[str, Type[BaseModel]]:
     return params_cls_dict


-def resolve_string_enum(enum_instance: Any, enum_type: Type[Enum]) -> Enum:
+def resolve_string_enum(enum_instance: Any, enum_type: type[Enum]) -> Enum:
     if not issubclass(enum_type, Enum):
         raise InvalidEnumValueError(f"🛑 `enum_type` must be a subclass of Enum. You provided: {enum_type}")
     invalid_enum_value_error = InvalidEnumValueError(
data_designer/config/utils/validation.py
@@ -5,7 +5,6 @@ from __future__ import annotations

 from enum import Enum
 from string import Formatter
-from typing import Optional

 from jinja2 import meta
 from jinja2.sandbox import ImmutableSandboxedEnvironment
@@ -15,8 +14,8 @@ from rich.console import Console, Group
 from rich.padding import Padding
 from rich.panel import Panel

-from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_llm_generated
-from data_designer.config.processors import ProcessorConfig, ProcessorType
+from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_model_generated
+from data_designer.config.processors import ProcessorConfigT, ProcessorType
 from data_designer.config.utils.constants import RICH_CONSOLE_THEME
 from data_designer.config.utils.misc import (
     can_run_data_designer_locally,
@@ -45,7 +44,7 @@ class ViolationLevel(str, Enum):


 class Violation(BaseModel):
-    column: Optional[str] = None
+    column: str | None = None
     type: ViolationType
     message: str
     level: ViolationLevel
@@ -57,7 +56,7 @@ class Violation(BaseModel):

 def validate_data_designer_config(
     columns: list[ColumnConfigT],
-    processor_configs: list[ProcessorConfig],
+    processor_configs: list[ProcessorConfigT],
     allowed_references: list[str],
 ) -> list[Violation]:
     violations = []
@@ -119,7 +118,7 @@ def validate_prompt_templates(
 ) -> list[Violation]:
     env = ImmutableSandboxedEnvironment()

-    columns_with_prompts = [c for c in columns if column_type_is_llm_generated(c.column_type)]
+    columns_with_prompts = [c for c in columns if column_type_is_model_generated(c.column_type)]

     violations = []
     for column in columns_with_prompts:
@@ -273,7 +272,7 @@ def validate_columns_not_all_dropped(

 def validate_drop_columns_processor(
     columns: list[ColumnConfigT],
-    processor_configs: list[ProcessorConfig],
+    processor_configs: list[ProcessorConfigT],
 ) -> list[Violation]:
     all_column_names = {c.name for c in columns}
     for processor_config in processor_configs:
@@ -294,7 +293,7 @@ def validate_drop_columns_processor(

 def validate_schema_transform_processor(
     columns: list[ColumnConfigT],
-    processor_configs: list[ProcessorConfig],
+    processor_configs: list[ProcessorConfigT],
 ) -> list[Violation]:
     violations = []

data_designer/config/utils/visualization.py
@@ -8,7 +8,7 @@ import os
 from collections import OrderedDict
 from enum import Enum
 from functools import cached_property
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Any

 import numpy as np
 import pandas as pd
@@ -36,11 +36,11 @@ if TYPE_CHECKING:
 console = Console()


-def get_nvidia_api_key() -> Optional[str]:
+def get_nvidia_api_key() -> str | None:
     return os.getenv(NVIDIA_API_KEY_ENV_VAR_NAME)


-def get_openai_api_key() -> Optional[str]:
+def get_openai_api_key() -> str | None:
     return os.getenv(OPENAI_API_KEY_ENV_VAR_NAME)


@@ -77,12 +77,12 @@ class WithRecordSamplerMixin:

     def display_sample_record(
         self,
-        index: Optional[int] = None,
+        index: int | None = None,
         *,
         hide_seed_columns: bool = False,
         syntax_highlighting_theme: str = "dracula",
-        background_color: Optional[str] = None,
-        processors_to_display: Optional[list[str]] = None,
+        background_color: str | None = None,
+        processors_to_display: list[str] | None = None,
     ) -> None:
         """Display a sample record from the Data Designer dataset preview.

@@ -134,11 +134,11 @@


 def create_rich_histogram_table(
-    data: dict[str, Union[int, float]],
+    data: dict[str, int | float],
     column_names: tuple[int, int],
     name_style: str = ColorPalette.BLUE.value,
     value_style: str = ColorPalette.TEAL.value,
-    title: Optional[str] = None,
+    title: str | None = None,
     **kwargs,
 ) -> Table:
     table = Table(title=title, **kwargs)
@@ -154,12 +154,12 @@


 def display_sample_record(
-    record: Union[dict, pd.Series, pd.DataFrame],
+    record: dict | pd.Series | pd.DataFrame,
     config_builder: DataDesignerConfigBuilder,
-    processor_data_to_display: Optional[dict[str, Union[list[str], str]]] = None,
-    background_color: Optional[str] = None,
+    processor_data_to_display: dict[str, list[str] | str] | None = None,
+    background_color: str | None = None,
     syntax_highlighting_theme: str = "dracula",
-    record_index: Optional[int] = None,
+    record_index: int | None = None,
     hide_seed_columns: bool = False,
 ):
     if isinstance(record, (dict, pd.Series)):
@@ -194,6 +194,7 @@ def display_sample_record(
194
194
  + config_builder.get_columns_of_type(DataDesignerColumnType.EXPRESSION)
195
195
  + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_TEXT)
196
196
  + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_STRUCTURED)
197
+ + config_builder.get_columns_of_type(DataDesignerColumnType.EMBEDDING)
197
198
  )
198
199
  if len(non_code_columns) > 0:
199
200
  table = Table(title="Generated Columns", **table_kws)
@@ -201,6 +202,10 @@
         table.add_column("Value")
         for col in non_code_columns:
             if not col.drop:
+                if col.column_type == DataDesignerColumnType.EMBEDDING:
+                    record[col.name]["embeddings"] = [
+                        get_truncated_list_as_string(embd) for embd in record[col.name].get("embeddings")
+                    ]
                 table.add_row(col.name, convert_to_row_element(record[col.name]))
         render_list.append(pad_console_element(table))

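
The truncation logic assumes an embedding cell is a dict whose "embeddings" key holds a list of vectors. A hypothetical record shape consistent with the lookup above (the actual schema comes from the new embedding generator, which is not shown in this diff):

    record = {
        "doc_embedding": {
            "embeddings": [[0.12, 0.34, 0.56], [0.78, 0.91, 0.11]],
        }
    }
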
 
@@ -269,9 +274,19 @@ def display_sample_record(
269
274
  console.print(Group(*render_list), markup=False)
270
275
 
271
276
 
277
+ def get_truncated_list_as_string(long_list: list[Any], max_items: int = 2) -> str:
278
+ if max_items <= 0:
279
+ raise ValueError("max_items must be greater than 0")
280
+ if len(long_list) > max_items:
281
+ truncated_part = long_list[:max_items]
282
+ return f"[{', '.join(str(x) for x in truncated_part)}, ...]"
283
+ else:
284
+ return str(long_list)
285
+
286
+
272
287
  def display_sampler_table(
273
288
  sampler_params: dict[SamplerType, ConfigBase],
274
- title: Optional[str] = None,
289
+ title: str | None = None,
275
290
  ) -> None:
276
291
  table = Table(expand=True)
277
292
  table.add_column("Type")
@@ -306,15 +321,15 @@ def display_model_configs_table(model_configs: list[ModelConfig]) -> None:
     table_model_configs.add_column("Alias")
     table_model_configs.add_column("Model")
     table_model_configs.add_column("Provider")
-    table_model_configs.add_column("Temperature")
-    table_model_configs.add_column("Top P")
+    table_model_configs.add_column("Inference Parameters")
     for model_config in model_configs:
+        params_display = model_config.inference_parameters.format_for_display()
+
         table_model_configs.add_row(
             model_config.alias,
             model_config.model,
             model_config.provider,
-            str(model_config.inference_parameters.temperature),
-            str(model_config.inference_parameters.top_p),
+            params_display,
         )
     group_args: list = [Rule(title="Model Configs"), table_model_configs]
     if len(model_configs) == 0:
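
Collapsing the fixed Temperature and Top P columns into a single Inference Parameters column lets the table display whatever parameters a model actually carries, including embedding settings such as encoding_format. The body of `format_for_display` is not part of this diff; a hypothetical sketch of what such a method might look like on a pydantic inference-parameters model:

    # Hypothetical illustration only; the real format_for_display is defined elsewhere.
    def format_for_display(self) -> str:
        return ", ".join(f"{name}={value}" for name, value in self.model_dump(exclude_none=True).items())
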
data_designer/config/validator_params.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from enum import Enum
-from typing import Any, Optional, Union
+from typing import Any

 from pydantic import Field, field_serializer, model_validator
 from typing_extensions import Self, TypeAlias
@@ -51,7 +51,7 @@ class LocalCallableValidatorParams(ConfigBase):
     validation_function: Any = Field(
         description="Function (Callable[[pd.DataFrame], pd.DataFrame]) to validate the data"
     )
-    output_schema: Optional[dict[str, Any]] = Field(
+    output_schema: dict[str, Any] | None = Field(
         default=None, description="Expected schema for local callable validator's output"
     )

@@ -80,7 +80,7 @@ class RemoteValidatorParams(ConfigBase):
     """

     endpoint_url: str = Field(description="URL of the remote endpoint")
-    output_schema: Optional[dict[str, Any]] = Field(
+    output_schema: dict[str, Any] | None = Field(
         default=None, description="Expected schema for remote validator's output"
     )
     timeout: float = Field(default=30.0, gt=0, description="The timeout for the HTTP request")
@@ -89,8 +89,4 @@ class RemoteValidatorParams(ConfigBase):
     max_parallel_requests: int = Field(default=4, ge=1, description="The maximum number of parallel requests to make")


-ValidatorParamsT: TypeAlias = Union[
-    CodeValidatorParams,
-    LocalCallableValidatorParams,
-    RemoteValidatorParams,
-]
+ValidatorParamsT: TypeAlias = CodeValidatorParams | LocalCallableValidatorParams | RemoteValidatorParams
data_designer/engine/analysis/column_profilers/base.py
@@ -7,7 +7,6 @@ import logging
 from abc import ABC, abstractmethod

 import pandas as pd
-import pyarrow as pa
 from pydantic import BaseModel, model_validator
 from typing_extensions import Self

@@ -29,12 +28,6 @@ class ColumnConfigWithDataFrame(ConfigBase):
             raise ValueError(f"Column {self.column_config.name!r} not found in DataFrame")
         return self

-    @model_validator(mode="after")
-    def ensure_pyarrow_backend(self) -> Self:
-        if not all(isinstance(dtype, pd.ArrowDtype) for dtype in self.df.dtypes):
-            self.df = pa.Table.from_pandas(self.df).to_pandas(types_mapper=pd.ArrowDtype)
-        return self
-
     def as_tuple(self) -> tuple[SingleColumnConfig, pd.DataFrame]:
         return (self.column_config, self.df)

data_designer/engine/analysis/column_profilers/judge_score_profiler.py
@@ -5,7 +5,6 @@ from __future__ import annotations

 import logging
 import random
-from typing import Union

 from data_designer.config.analysis.column_profilers import (
     JudgeScoreProfilerConfig,
@@ -69,7 +68,7 @@ class JudgeScoreProfiler(ColumnProfiler[JudgeScoreProfilerConfig]):
         )

         for score in column_config.scores:
-            score_name = score.name.lower()
+            score_name = score.name
            logger.info(f"{random.choice(['👩‍⚖️', '👨‍⚖️'])} Summarizing LLM-as-judge score: '{score_name}'")
             score_sample = sample_scores_and_reasoning(
                 scores=score_distributions.scores[score_name],
@@ -96,7 +95,7 @@ class JudgeScoreProfiler(ColumnProfiler[JudgeScoreProfilerConfig]):
         name: str,
         sample: list[JudgeScoreSample],
         histogram: CategoricalHistogramData,
-        distribution: Union[CategoricalDistribution, NumericalDistribution, MissingValue],
+        distribution: CategoricalDistribution | NumericalDistribution | MissingValue,
         distribution_type: ColumnDistributionType,
     ) -> JudgeScoreSummary:
         if isinstance(distribution, MissingValue) or not sample:
data_designer/engine/analysis/column_statistics.py
@@ -4,7 +4,7 @@
 from __future__ import annotations

 import logging
-from typing import Any, Type, TypeAlias, Union
+from typing import Any, TypeAlias

 import pandas as pd
 from pydantic import BaseModel
@@ -41,7 +41,7 @@ class GeneralColumnStatisticsCalculator(BaseModel):
         return self.column_config_with_df.df

     @property
-    def column_statistics_type(self) -> Type[ColumnStatisticsT]:
+    def column_statistics_type(self) -> type[ColumnStatisticsT]:
         return DEFAULT_COLUMN_STATISTICS_MAP.get(self.column_config.column_type, GeneralColumnStatistics)

     def calculate(self) -> Self:
@@ -59,7 +59,7 @@ class GeneralColumnStatisticsCalculator(BaseModel):
         )

     def calculate_general_column_info(self) -> dict[str, Any]:
-        return calculate_general_column_info(self.column_config, self.df)
+        return calculate_general_column_info(self.column_config.name, self.df)

     def __repr__(self) -> str:
         params = []
@@ -93,7 +93,7 @@ class SamplerColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
         return (
             {
                 "sampler_type": SamplerType(self.column_config.sampler_type),
-                **calculate_column_distribution(self.column_config, self.df, dist_type),
+                **calculate_column_distribution(self.column_config.name, self.df, dist_type),
             }
             if make_dist
             else {
@@ -109,23 +109,23 @@ class SeedDatasetColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):

 class ValidationColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
     def calculate_validation_column_info(self) -> dict[str, Any]:
-        return calculate_validation_column_info(self.column_config, self.df)
+        return calculate_validation_column_info(self.column_config.name, self.df)


 class ExpressionColumnStatisticsCalculator(GeneralColumnStatisticsCalculator): ...


-ColumnStatisticsCalculatorT: TypeAlias = Union[
-    ExpressionColumnStatisticsCalculator,
-    ValidationColumnStatisticsCalculator,
-    GeneralColumnStatisticsCalculator,
-    LLMCodeColumnStatisticsCalculator,
-    LLMJudgedColumnStatisticsCalculator,
-    LLMStructuredColumnStatisticsCalculator,
-    LLMTextColumnStatisticsCalculator,
-    SamplerColumnStatisticsCalculator,
-    SeedDatasetColumnStatisticsCalculator,
-]
+ColumnStatisticsCalculatorT: TypeAlias = (
+    ExpressionColumnStatisticsCalculator
+    | ValidationColumnStatisticsCalculator
+    | GeneralColumnStatisticsCalculator
+    | LLMCodeColumnStatisticsCalculator
+    | LLMJudgedColumnStatisticsCalculator
+    | LLMStructuredColumnStatisticsCalculator
+    | LLMTextColumnStatisticsCalculator
+    | SamplerColumnStatisticsCalculator
+    | SeedDatasetColumnStatisticsCalculator
+)
 DEFAULT_COLUMN_STATISTICS_CALCULATOR_MAP = {
     DataDesignerColumnType.EXPRESSION: ExpressionColumnStatisticsCalculator,
     DataDesignerColumnType.VALIDATION: ValidationColumnStatisticsCalculator,
data_designer/engine/analysis/dataset_profiler.py
@@ -6,6 +6,7 @@ from collections.abc import Sequence
 from functools import cached_property

 import pandas as pd
+import pyarrow as pa
 from pydantic import Field, field_validator

 from data_designer.config.analysis.column_profilers import ColumnProfilerConfigT
@@ -19,10 +20,8 @@ from data_designer.config.column_types import (
 from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame, ColumnProfiler
 from data_designer.engine.analysis.column_statistics import get_column_statistics_calculator
 from data_designer.engine.analysis.errors import DatasetProfilerConfigurationError
-from data_designer.engine.dataset_builders.multi_column_configs import (
-    DatasetBuilderColumnConfigT,
-    MultiColumnConfig,
-)
+from data_designer.engine.analysis.utils.column_statistics_calculations import has_pyarrow_backend
+from data_designer.engine.dataset_builders.multi_column_configs import DatasetBuilderColumnConfigT, MultiColumnConfig
 from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
 from data_designer.engine.resources.resource_provider import ResourceProvider

@@ -68,6 +67,7 @@ class DataDesignerDatasetProfiler:
         logger.info("📐 Measuring dataset column statistics:")

         self._validate_schema_consistency(list(dataset.columns))
+        dataset = self._convert_to_pyarrow_backend_if_needed(dataset)

         column_statistics = []
         for c in self.config.column_configs:
@@ -100,6 +100,27 @@ class DataDesignerDatasetProfiler:
             column_profiles=column_profiles if column_profiles else None,
         )

+    def _convert_to_pyarrow_backend_if_needed(self, dataset: pd.DataFrame) -> pd.DataFrame:
+        if not has_pyarrow_backend(dataset):
+            try:
+                dataset = pa.Table.from_pandas(dataset).to_pandas(types_mapper=pd.ArrowDtype)
+            except Exception as e:
+                # For ArrowTypeError, the second arg contains the more informative message
+                if isinstance(e, pa.lib.ArrowTypeError) and len(e.args) > 1:
+                    error_msg = str(e.args[1])
+                else:
+                    error_msg = str(e)
+                for col in dataset.columns:
+                    # Make sure column names are clear in the error message
+                    error_msg = error_msg.replace(col, f"'{col}'")
+                logger.warning("⚠️ Unable to convert the dataset to a PyArrow backend")
+                logger.warning(f"  |-- Conversion Error Message: {error_msg}")
+                logger.warning("  |-- This is often due to at least one column having mixed data types")
+                logger.warning(
+                    "  |-- Note: Reported data types will be inferred from the first non-null value of each column"
+                )
+        return dataset
+
     def _create_column_profiler(self, profiler_config: ColumnProfilerConfigT) -> ColumnProfiler:
         return self.registry.column_profilers.get_for_config_type(type(profiler_config))(
             config=profiler_config, resource_provider=self.resource_provider
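
The conversion that 0.1.5 performed per column config (the `ensure_pyarrow_backend` validator removed above) now happens once per dataset, with a warning instead of a hard failure when PyArrow cannot infer a column's type. A minimal sketch of the round trip, assuming only pandas and pyarrow:

    import pandas as pd
    import pyarrow as pa

    df = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})  # NumPy-backed dtypes
    df_arrow = pa.Table.from_pandas(df).to_pandas(types_mapper=pd.ArrowDtype)
    # The check the removed validator used (and what has_pyarrow_backend presumably wraps):
    assert all(isinstance(dtype, pd.ArrowDtype) for dtype in df_arrow.dtypes)
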