data-designer-engine 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114) hide show
  1. data_designer/engine/__init__.py +2 -0
  2. data_designer/engine/_version.py +34 -0
  3. data_designer/engine/analysis/column_profilers/base.py +49 -0
  4. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +153 -0
  5. data_designer/engine/analysis/column_profilers/registry.py +22 -0
  6. data_designer/engine/analysis/column_statistics.py +145 -0
  7. data_designer/engine/analysis/dataset_profiler.py +149 -0
  8. data_designer/engine/analysis/errors.py +9 -0
  9. data_designer/engine/analysis/utils/column_statistics_calculations.py +234 -0
  10. data_designer/engine/analysis/utils/judge_score_processing.py +132 -0
  11. data_designer/engine/column_generators/__init__.py +2 -0
  12. data_designer/engine/column_generators/generators/__init__.py +2 -0
  13. data_designer/engine/column_generators/generators/base.py +122 -0
  14. data_designer/engine/column_generators/generators/embedding.py +35 -0
  15. data_designer/engine/column_generators/generators/expression.py +55 -0
  16. data_designer/engine/column_generators/generators/llm_completion.py +116 -0
  17. data_designer/engine/column_generators/generators/samplers.py +69 -0
  18. data_designer/engine/column_generators/generators/seed_dataset.py +144 -0
  19. data_designer/engine/column_generators/generators/validation.py +140 -0
  20. data_designer/engine/column_generators/registry.py +60 -0
  21. data_designer/engine/column_generators/utils/errors.py +15 -0
  22. data_designer/engine/column_generators/utils/generator_classification.py +43 -0
  23. data_designer/engine/column_generators/utils/judge_score_factory.py +58 -0
  24. data_designer/engine/column_generators/utils/prompt_renderer.py +100 -0
  25. data_designer/engine/compiler.py +97 -0
  26. data_designer/engine/configurable_task.py +71 -0
  27. data_designer/engine/dataset_builders/artifact_storage.py +283 -0
  28. data_designer/engine/dataset_builders/column_wise_builder.py +354 -0
  29. data_designer/engine/dataset_builders/errors.py +15 -0
  30. data_designer/engine/dataset_builders/multi_column_configs.py +46 -0
  31. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  32. data_designer/engine/dataset_builders/utils/concurrency.py +212 -0
  33. data_designer/engine/dataset_builders/utils/config_compiler.py +62 -0
  34. data_designer/engine/dataset_builders/utils/dag.py +62 -0
  35. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +200 -0
  36. data_designer/engine/dataset_builders/utils/errors.py +15 -0
  37. data_designer/engine/dataset_builders/utils/progress_tracker.py +122 -0
  38. data_designer/engine/errors.py +51 -0
  39. data_designer/engine/model_provider.py +77 -0
  40. data_designer/engine/models/__init__.py +2 -0
  41. data_designer/engine/models/errors.py +300 -0
  42. data_designer/engine/models/facade.py +284 -0
  43. data_designer/engine/models/factory.py +42 -0
  44. data_designer/engine/models/litellm_overrides.py +179 -0
  45. data_designer/engine/models/parsers/__init__.py +2 -0
  46. data_designer/engine/models/parsers/errors.py +34 -0
  47. data_designer/engine/models/parsers/parser.py +235 -0
  48. data_designer/engine/models/parsers/postprocessors.py +93 -0
  49. data_designer/engine/models/parsers/tag_parsers.py +62 -0
  50. data_designer/engine/models/parsers/types.py +84 -0
  51. data_designer/engine/models/recipes/base.py +81 -0
  52. data_designer/engine/models/recipes/response_recipes.py +293 -0
  53. data_designer/engine/models/registry.py +151 -0
  54. data_designer/engine/models/telemetry.py +362 -0
  55. data_designer/engine/models/usage.py +73 -0
  56. data_designer/engine/models/utils.py +101 -0
  57. data_designer/engine/processing/ginja/__init__.py +2 -0
  58. data_designer/engine/processing/ginja/ast.py +65 -0
  59. data_designer/engine/processing/ginja/environment.py +463 -0
  60. data_designer/engine/processing/ginja/exceptions.py +56 -0
  61. data_designer/engine/processing/ginja/record.py +32 -0
  62. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  63. data_designer/engine/processing/gsonschema/exceptions.py +15 -0
  64. data_designer/engine/processing/gsonschema/schema_transformers.py +83 -0
  65. data_designer/engine/processing/gsonschema/types.py +10 -0
  66. data_designer/engine/processing/gsonschema/validators.py +202 -0
  67. data_designer/engine/processing/processors/base.py +13 -0
  68. data_designer/engine/processing/processors/drop_columns.py +42 -0
  69. data_designer/engine/processing/processors/registry.py +25 -0
  70. data_designer/engine/processing/processors/schema_transform.py +71 -0
  71. data_designer/engine/processing/utils.py +169 -0
  72. data_designer/engine/registry/base.py +99 -0
  73. data_designer/engine/registry/data_designer_registry.py +39 -0
  74. data_designer/engine/registry/errors.py +12 -0
  75. data_designer/engine/resources/managed_dataset_generator.py +39 -0
  76. data_designer/engine/resources/managed_dataset_repository.py +197 -0
  77. data_designer/engine/resources/managed_storage.py +65 -0
  78. data_designer/engine/resources/resource_provider.py +77 -0
  79. data_designer/engine/resources/seed_reader.py +154 -0
  80. data_designer/engine/sampling_gen/column.py +91 -0
  81. data_designer/engine/sampling_gen/constraints.py +100 -0
  82. data_designer/engine/sampling_gen/data_sources/base.py +217 -0
  83. data_designer/engine/sampling_gen/data_sources/errors.py +12 -0
  84. data_designer/engine/sampling_gen/data_sources/sources.py +347 -0
  85. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  86. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  87. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +90 -0
  88. data_designer/engine/sampling_gen/entities/email_address_utils.py +171 -0
  89. data_designer/engine/sampling_gen/entities/errors.py +10 -0
  90. data_designer/engine/sampling_gen/entities/national_id_utils.py +102 -0
  91. data_designer/engine/sampling_gen/entities/person.py +144 -0
  92. data_designer/engine/sampling_gen/entities/phone_number.py +128 -0
  93. data_designer/engine/sampling_gen/errors.py +26 -0
  94. data_designer/engine/sampling_gen/generator.py +122 -0
  95. data_designer/engine/sampling_gen/jinja_utils.py +64 -0
  96. data_designer/engine/sampling_gen/people_gen.py +199 -0
  97. data_designer/engine/sampling_gen/person_constants.py +56 -0
  98. data_designer/engine/sampling_gen/schema.py +147 -0
  99. data_designer/engine/sampling_gen/schema_builder.py +61 -0
  100. data_designer/engine/sampling_gen/utils.py +46 -0
  101. data_designer/engine/secret_resolver.py +82 -0
  102. data_designer/engine/testing/__init__.py +12 -0
  103. data_designer/engine/testing/stubs.py +133 -0
  104. data_designer/engine/testing/utils.py +20 -0
  105. data_designer/engine/validation.py +367 -0
  106. data_designer/engine/validators/__init__.py +19 -0
  107. data_designer/engine/validators/base.py +38 -0
  108. data_designer/engine/validators/local_callable.py +39 -0
  109. data_designer/engine/validators/python.py +254 -0
  110. data_designer/engine/validators/remote.py +89 -0
  111. data_designer/engine/validators/sql.py +65 -0
  112. data_designer_engine-0.4.0.dist-info/METADATA +50 -0
  113. data_designer_engine-0.4.0.dist-info/RECORD +114 -0
  114. data_designer_engine-0.4.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,116 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import functools
7
+ import logging
8
+
9
+ from data_designer.config.column_configs import (
10
+ LLMCodeColumnConfig,
11
+ LLMJudgeColumnConfig,
12
+ LLMStructuredColumnConfig,
13
+ LLMTextColumnConfig,
14
+ )
15
+ from data_designer.config.utils.constants import TRACE_COLUMN_POSTFIX
16
+ from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModel, GenerationStrategy
17
+ from data_designer.engine.column_generators.utils.prompt_renderer import (
18
+ PromptType,
19
+ RecordBasedPromptRenderer,
20
+ create_response_recipe,
21
+ )
22
+ from data_designer.engine.configurable_task import TaskConfigT
23
+ from data_designer.engine.models.recipes.base import ResponseRecipe
24
+ from data_designer.engine.processing.utils import deserialize_json_values
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
class ColumnGeneratorWithModelChatCompletion(ColumnGeneratorWithModel[TaskConfigT]):
    """Base generator for columns produced via LLM chat completions.

    Renders user/system prompts from the current record, calls the model
    one cell (record) at a time, parses the response through a response
    recipe, and writes the serialized result — plus an optional conversation
    trace — back into the record dict.
    """

    @staticmethod
    def get_generation_strategy() -> GenerationStrategy:
        # Chat-completion columns are generated record-by-record, not as whole columns.
        return GenerationStrategy.CELL_BY_CELL

    @functools.cached_property
    def response_recipe(self) -> ResponseRecipe:
        # Cached: the recipe depends only on the (immutable) column and model configs.
        return create_response_recipe(self.config, self.model_config)

    @property
    def max_conversation_correction_steps(self) -> int:
        # Upper bound on in-conversation correction turns, taken from the run config.
        return self.resource_provider.run_config.max_conversation_correction_steps

    @property
    def max_conversation_restarts(self) -> int:
        # Upper bound on full-conversation restarts, taken from the run config.
        return self.resource_provider.run_config.max_conversation_restarts

    @functools.cached_property
    def prompt_renderer(self) -> RecordBasedPromptRenderer:
        # The error-message context lets render failures identify this column/model.
        return RecordBasedPromptRenderer(
            response_recipe=self.response_recipe,
            error_message_context={
                "column_name": self.config.name,
                "column_type": self.config.column_type,
                "model_alias": self.config.model_alias,
            },
        )

    def generate(self, data: dict) -> dict:
        """Generate this column's value for a single record.

        Mutates `data` in place (adds the generated column and, when tracing
        is enabled, a companion trace column) and returns it.
        """
        # Deserialize input data from previous columns so Jinja2 templates can access nested fields
        # Example: If prev column stored '{"key": "value"}', templates can use {{ prev_column.key }}
        # Note: This creates a new dict and doesn't mutate the original `data` argument
        deserialized_record = deserialize_json_values(data)

        # Collect multi-modal context (if configured), resolved against the current record.
        multi_modal_context = None
        if self.config.multi_modal_context is not None and len(self.config.multi_modal_context) > 0:
            multi_modal_context = []
            for context in self.config.multi_modal_context:
                multi_modal_context.extend(context.get_contexts(deserialized_record))

        response, trace = self.model.generate(
            prompt=self.prompt_renderer.render(
                record=deserialized_record,
                prompt_template=self.config.prompt,
                prompt_type=PromptType.USER_PROMPT,
            ),
            system_prompt=self.prompt_renderer.render(
                record=deserialized_record,
                prompt_template=self.config.system_prompt,
                prompt_type=PromptType.SYSTEM_PROMPT,
            ),
            parser=self.response_recipe.parse,
            multi_modal_context=multi_modal_context,
            max_correction_steps=self.max_conversation_correction_steps,
            max_conversation_restarts=self.max_conversation_restarts,
            purpose=f"running generation for column '{self.config.name}'",
        )

        # Serialize the parsed response; subclasses may post-process (e.g. JSON-decode).
        serialized_output = self.response_recipe.serialize_output(response)
        data[self.config.name] = self._process_serialized_output(serialized_output)

        # Save the full conversation trace either per-column opt-in or via the
        # debug override that forces traces for all columns.
        should_save_trace = (
            self.config.with_trace or self.resource_provider.run_config.debug_override_save_all_column_traces
        )
        if should_save_trace:
            data[self.config.name + TRACE_COLUMN_POSTFIX] = [message.to_dict() for message in trace]

        return data

    def _process_serialized_output(self, serialized_output: str) -> str | dict | list:
        """Process the serialized output from the model. Subclasses can override to customize deserialization."""
        return serialized_output
101
+
102
+
103
class LLMTextCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMTextColumnConfig]):
    """Cell generator for free-text LLM columns; output is kept as the serialized string."""

    ...


class LLMCodeCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMCodeColumnConfig]):
    """Cell generator for code LLM columns; output is kept as the serialized string."""

    ...


class LLMStructuredCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMStructuredColumnConfig]):
    """Cell generator for structured LLM columns; serialized JSON is decoded into Python objects."""

    def _process_serialized_output(self, serialized_output: str) -> dict | list:
        # Structured outputs are stored as parsed JSON rather than raw strings.
        return deserialize_json_values(serialized_output)


class LLMJudgeCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMJudgeColumnConfig]):
    """Cell generator for LLM-judge columns; serialized JSON is decoded into Python objects."""

    def _process_serialized_output(self, serialized_output: str) -> dict | list:
        # Judge outputs are stored as parsed JSON rather than raw strings.
        return deserialize_json_values(serialized_output)
@@ -0,0 +1,69 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import logging
7
+ import random
8
+ from functools import partial
9
+ from typing import TYPE_CHECKING, Callable
10
+
11
+ from data_designer.config.utils.constants import LOCALES_WITH_MANAGED_DATASETS
12
+ from data_designer.engine.column_generators.generators.base import FromScratchColumnGenerator, GenerationStrategy
13
+ from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
14
+ from data_designer.engine.processing.utils import concat_datasets
15
+ from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
16
+ from data_designer.engine.sampling_gen.data_sources.sources import SamplerType
17
+ from data_designer.engine.sampling_gen.entities.person import load_person_data_sampler
18
+ from data_designer.engine.sampling_gen.generator import DatasetGenerator as SamplingDatasetGenerator
19
+ from data_designer.lazy_heavy_imports import pd
20
+
21
+ if TYPE_CHECKING:
22
+ import pandas as pd
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
class SamplerColumnGenerator(FromScratchColumnGenerator[SamplerMultiColumnConfig]):
    """Generator for sampler-based columns.

    Produces entire columns from scratch via a SamplingDatasetGenerator and
    concatenates them onto the incoming dataset. Person columns backed by a
    managed dataset additionally wire in a lazy person-data loader.
    """

    @staticmethod
    def get_generation_strategy() -> GenerationStrategy:
        # Sampler columns are generated as whole columns, not cell-by-cell.
        return GenerationStrategy.FULL_COLUMN

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        """Append freshly sampled columns to `data`, one row per existing row."""
        df_samplers = self.generate_from_scratch(len(data))
        return concat_datasets([data, df_samplers])

    def generate_from_scratch(self, num_records: int) -> pd.DataFrame:
        """Generate `num_records` rows for all configured sampler columns."""
        sampling_generator = self._prepare_for_generation(num_records)
        return sampling_generator.generate(num_records)

    @property
    def _person_columns(self) -> list:
        # Single source of truth for "columns using the PERSON sampler";
        # previously this filter was duplicated in two methods.
        return [c for c in self.config.columns if c.sampler_type == SamplerType.PERSON]

    @property
    def _needs_person_generator(self) -> bool:
        """True when any person column uses a locale backed by a managed dataset."""
        return any(c.params.locale in LOCALES_WITH_MANAGED_DATASETS for c in self._person_columns)

    @property
    def _person_generator_loader(self) -> Callable[[bool], ManagedDatasetGenerator]:
        # Deferred via partial so person data is only fetched if actually needed.
        return partial(load_person_data_sampler, blob_storage=self.resource_provider.blob_storage)

    def _create_sampling_dataset_generator(self) -> SamplingDatasetGenerator:
        """Build the underlying sampling generator, wiring in person data only if required."""
        return SamplingDatasetGenerator(
            sampler_columns=self.config,
            person_generator_loader=(self._person_generator_loader if self._needs_person_generator else None),
        )

    def _log_person_generation_if_needed(self) -> None:
        """Emit a friendly log line when person generation is about to be initialized."""
        if self._needs_person_generator:
            emoji = random.choice(["🧑‍🎨", "🙋‍♂️", "🙋‍♀️", "🧑‍🚀", "👩‍🎤", "👨‍🍳", "👩‍🔬", "👨‍💻", "👩‍💼"])
            log_msg = f"🎲 {emoji} Initializing person generation"
            if any(c.params.with_synthetic_personas for c in self._person_columns):
                log_msg += " ⚡️ with synthetic personas ⚡️"
            logger.info(log_msg)

    def _prepare_for_generation(self, num_records: int) -> SamplingDatasetGenerator:
        """Log the upcoming generation and construct the sampling generator."""
        logger.info(
            f"🎲 Preparing samplers to generate {num_records} records across {len(self.config.columns)} columns"
        )
        self._log_person_generation_if_needed()
        return self._create_sampling_dataset_generator()
@@ -0,0 +1,144 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import functools
7
+ import logging
8
+ from typing import TYPE_CHECKING
9
+
10
+ from data_designer.config.seed import IndexRange, PartitionBlock, SamplingStrategy
11
+ from data_designer.engine.column_generators.generators.base import FromScratchColumnGenerator, GenerationStrategy
12
+ from data_designer.engine.column_generators.utils.errors import SeedDatasetError
13
+ from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig
14
+ from data_designer.engine.processing.utils import concat_datasets
15
+ from data_designer.lazy_heavy_imports import duckdb, pd
16
+
17
+ if TYPE_CHECKING:
18
+ import duckdb
19
+ import pandas as pd
20
+
21
+ MAX_ZERO_RECORD_RESPONSE_FACTOR = 2
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColumnConfig]):
    """Generator that fills columns by sampling records from a seed dataset.

    Reads the seed dataset through DuckDB in record batches, honoring an
    optional selection strategy (index range or partition block) and an
    optional shuffle sampling strategy. Sampling wraps around: when the
    reader is exhausted it is reset, so more records than the seed dataset
    size can be drawn.
    """

    @staticmethod
    def get_generation_strategy() -> GenerationStrategy:
        # Seed-dataset columns are produced as whole columns, not cell-by-cell.
        return GenerationStrategy.FULL_COLUMN

    @property
    def num_records_sampled(self) -> int:
        # Running total of records handed out across all calls.
        return self._num_records_sampled

    @functools.cached_property
    def duckdb_conn(self) -> duckdb.DuckDBPyConnection:
        # One connection per generator instance, created lazily.
        return self.resource_provider.seed_reader.create_duckdb_connection()

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        """Prepend sampled seed columns to `data` (one sampled row per existing row)."""
        return concat_datasets([self.generate_from_scratch(len(data)), data])

    def generate_from_scratch(self, num_records: int) -> pd.DataFrame:
        """Sample `num_records` records from the seed dataset.

        Raises:
            ValueError: If `num_records` is not positive.
        """
        if num_records <= 0:
            raise ValueError("🛑 `num_records` must be positive.")

        # Lazily create the batch reader on first use.
        if self._batch_reader is None:
            self._reset_batch_reader(num_records)

        return self._sample_records(num_records)

    def _initialize(self) -> None:
        # Set up sampling state: nothing sampled yet, no reader, no leftover rows.
        self._num_records_sampled = 0
        self._batch_reader = None
        self._df_remaining = None
        self._dataset_uri = self.resource_provider.seed_reader.get_dataset_uri()
        self._seed_dataset_size = self.duckdb_conn.execute(f"SELECT COUNT(*) FROM '{self._dataset_uri}'").fetchone()[0]
        self._index_range = self._resolve_index_range()

    def _validate_selection_strategy(self) -> None:
        """Raise SeedDatasetError if the configured selection strategy is out of bounds."""
        err_msg = None
        if self.config.selection_strategy is not None:
            if (
                isinstance(self.config.selection_strategy, IndexRange)
                and self.config.selection_strategy.end >= self._seed_dataset_size
            ):
                err_msg = f"Selection strategy 'end' index {self.config.selection_strategy.end} is out of bounds for dataset size {self._seed_dataset_size}"
            elif (
                isinstance(self.config.selection_strategy, PartitionBlock)
                and self.config.selection_strategy.num_partitions > self._seed_dataset_size
            ):
                err_msg = f"Selection strategy 'num_partitions' {self.config.selection_strategy.num_partitions} is out of bounds for dataset size {self._seed_dataset_size}"
        if err_msg is not None:
            raise SeedDatasetError(err_msg)

    def _resolve_index_range(self) -> IndexRange | None:
        """Normalize the selection strategy into an IndexRange (or None for no selection)."""
        self._validate_selection_strategy()
        index_range = None
        if self.config.selection_strategy is not None:
            if isinstance(self.config.selection_strategy, IndexRange):
                index_range = self.config.selection_strategy
            elif isinstance(self.config.selection_strategy, PartitionBlock):
                # Partition blocks are converted to their equivalent row range.
                index_range = self.config.selection_strategy.to_index_range(self._seed_dataset_size)
        return index_range

    def _reset_batch_reader(self, num_records: int) -> None:
        """(Re)build the DuckDB record-batch reader, applying selection and shuffle."""
        shuffle = self.config.sampling_strategy == SamplingStrategy.SHUFFLE
        shuffle_query = " ORDER BY RANDOM()" if shuffle else ""

        if self._index_range is not None:
            # Use LIMIT and OFFSET for efficient index range filtering
            # IndexRange uses 0-based indexing [start, end] inclusive
            # OFFSET skips the first 'start' rows (0-based)
            # LIMIT takes 'end - start + 1' rows to include both start and end (inclusive)
            offset_value = self._index_range.start
            limit_value = self._index_range.end - self._index_range.start + 1
            read_query = f"""
            SELECT * FROM '{self._dataset_uri}'
            LIMIT {limit_value} OFFSET {offset_value}
            """
            # Shuffle (if any) is applied over the selected subset, not the full dataset.
            read_query = f"SELECT * FROM ({read_query}){shuffle_query}"
        else:
            read_query = f"SELECT * FROM '{self._dataset_uri}'{shuffle_query}"
        self._batch_reader = self.duckdb_conn.query(read_query).record_batch(batch_size=num_records)

    def _sample_records(self, num_records: int) -> pd.DataFrame:
        """Pull batches until `num_records` rows are collected, wrapping around as needed."""
        logger.info(f"🌱 Sampling {num_records} records from seed dataset")
        logger.info(f" |-- seed dataset size: {self._seed_dataset_size} records")
        logger.info(f" |-- sampling strategy: {self.config.sampling_strategy}")
        if self._index_range is not None:
            if isinstance(self.config.selection_strategy, IndexRange):
                logger.info(f" |-- selection: rows [{self._index_range.start} to {self._index_range.end}] inclusive")
            else:
                logger.info(
                    f" |-- selection: partition {self.config.selection_strategy.index + 1} of {self.config.selection_strategy.num_partitions}"
                )
            logger.info(f" |-- seed dataset size after selection: {self._index_range.size} records")
        df_batch = pd.DataFrame()
        # Start from any rows left over from the previous call.
        df_sample = pd.DataFrame() if self._df_remaining is None else self._df_remaining
        num_zero_record_responses = 0

        while len(df_sample) < num_records:
            try:
                df_batch = self._batch_reader.read_next_batch().to_pandas()
                df_sample = pd.concat([df_sample, df_batch], ignore_index=True)
            except StopIteration:
                # Reader exhausted: wrap around and keep sampling.
                self._reset_batch_reader(num_records)

            # Guard against an endlessly-empty reader (e.g. datastore trouble):
            # bail out after too many consecutive zero-row batches.
            if len(df_batch) == 0:
                num_zero_record_responses += 1
                if num_zero_record_responses > MAX_ZERO_RECORD_RESPONSE_FACTOR * num_records:
                    raise RuntimeError(
                        "🛑 Something went wrong while reading from the datastore. "
                        "Please check your connection and try again. "
                        "If the issue persists, please contact support."
                    )

        # Stash any overshoot for the next call and trim to the requested size.
        self._df_remaining = None
        if len(df_sample) > num_records:
            self._df_remaining = df_sample.iloc[num_records:].reset_index(drop=True)
            df_sample = df_sample.iloc[:num_records]
        self._num_records_sampled += len(df_sample)

        return df_sample
@@ -0,0 +1,140 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import logging
7
+ from typing import TYPE_CHECKING
8
+
9
+ from data_designer.config.column_configs import ValidationColumnConfig
10
+ from data_designer.config.errors import InvalidConfigError
11
+ from data_designer.config.utils.code_lang import SQL_DIALECTS, CodeLang
12
+ from data_designer.config.validator_params import ValidatorParamsT, ValidatorType
13
+ from data_designer.engine.column_generators.generators.base import ColumnGeneratorFullColumn
14
+ from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor
15
+ from data_designer.engine.errors import DataDesignerRuntimeError
16
+ from data_designer.engine.validators import (
17
+ BaseValidator,
18
+ LocalCallableValidator,
19
+ PythonValidator,
20
+ RemoteValidator,
21
+ SQLValidator,
22
+ ValidationResult,
23
+ )
24
+ from data_designer.lazy_heavy_imports import pd
25
+
26
+ if TYPE_CHECKING:
27
+ import pandas as pd
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
def get_validator_from_params(validator_type: ValidatorType, validator_params: ValidatorParamsT) -> BaseValidator:
    """Instantiate the validator implementation for the given type and params.

    Args:
        validator_type: The configured validator kind (CODE, REMOTE, or local callable).
        validator_params: The parameters for the validator implementation.

    Returns:
        A concrete BaseValidator instance.

    Raises:
        InvalidConfigError: If a CODE validator is requested for a code
            language with no validator implementation.
    """
    if validator_type == ValidatorType.CODE:
        if validator_params.code_lang == CodeLang.PYTHON:
            return PythonValidator(validator_params)
        if validator_params.code_lang in SQL_DIALECTS:
            return SQLValidator(validator_params)
        # Previously this branch fell through and implicitly returned None,
        # producing a confusing AttributeError downstream. Fail fast instead.
        raise InvalidConfigError(f"No code validator available for code language {validator_params.code_lang!r}")
    if validator_type == ValidatorType.REMOTE:
        return RemoteValidator(validator_params)
    return LocalCallableValidator(validator_params)
42
+
43
+
44
class ValidationColumnGenerator(ColumnGeneratorFullColumn[ValidationColumnConfig]):
    """Runs a configured validator over target columns and appends per-record results.

    Records are validated in batches; REMOTE validators may run batches in
    parallel. Results are stored as one dict per record in a new column named
    after this config.
    """

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        """Validate `data` and return it with the results column appended.

        Raises:
            InvalidConfigError: If any configured target column is missing from `data`.
            DataDesignerRuntimeError: If parallel validation fails unexpectedly.
        """
        logger.info(f"🔍 Validating column {self.config.name!r} with {len(data)} records")
        logger.info(f" |-- target columns: {self.config.target_columns}")
        logger.info(f" |-- validator type: {self.config.validator_type}")
        logger.info(f" |-- validator params: {self.config.validator_params}")
        logger.info(f" |-- batch size: {self.config.batch_size}")

        validator = get_validator_from_params(self.config.validator_type, self.config.validator_params)

        # Check if the target columns are present in the dataset
        missing_columns = set(self.config.target_columns) - set(data.columns)
        if missing_columns:
            raise InvalidConfigError(
                f"Target columns {missing_columns} defined in validation column {self.config.name!r} are missing in dataset"
            )

        # Check whether to pass single columns or multiple columns to the validator
        validate_columns_separately = False
        if self.config.validator_type == ValidatorType.CODE and len(self.config.target_columns) > 1:
            # Code validator expects single column input, so we validate each column separately
            validate_columns_separately = True
            columns_to_validate = [[col] for col in self.config.target_columns]
        else:
            columns_to_validate = [self.config.target_columns]

        outputs_as_dicts = None
        for cols in columns_to_validate:
            # Filter the dataset to only include the target columns, and convert to a list of dictionaries
            records = data[cols].to_dict(orient="records")

            batched_records = [
                records[batch_start : batch_start + self.config.batch_size]
                for batch_start in range(0, len(records), self.config.batch_size)
            ]

            # Run validation in parallel or sequentially, depending on the validator type and parameters
            if (
                self.config.validator_type == ValidatorType.REMOTE
                and self.config.validator_params.max_parallel_requests > 1
            ):
                concatenated_outputs = self._validate_in_parallel(validator, batched_records)
            else:
                concatenated_outputs = []
                for batch in batched_records:
                    concatenated_outputs.extend(self._validate_batch(validator, batch))

            if validate_columns_separately:
                if outputs_as_dicts is None:
                    # First column: one result dict per record.
                    outputs_as_dicts = [{cols[0]: output.model_dump(mode="json")} for output in concatenated_outputs]
                else:
                    # Subsequent columns: merge each record's output into its dict.
                    # BUGFIX: previously every record received concatenated_outputs[0]
                    # (the first record's result) instead of its own.
                    for dict_output, output in zip(outputs_as_dicts, concatenated_outputs):
                        dict_output[cols[0]] = output.model_dump(mode="json")
            else:
                outputs_as_dicts = [output.model_dump(mode="json") for output in concatenated_outputs]

        validation_results = pd.DataFrame({self.config.name: outputs_as_dicts})
        return pd.concat([data, validation_results], axis=1)

    def _validate_in_parallel(self, validator: BaseValidator, batched_records: list[list[dict]]) -> list:
        """Run batch validation across threads, preserving batch order.

        Returns the concatenated per-record outputs (a flat list), matching the
        shape produced by the sequential path in `generate`. Failed batches are
        replaced with empty results via the error callback.
        """

        outputs = [None] * len(batched_records)

        def result_callback(result: ValidationResult, context: dict):
            # Slot results by submission index so ordering is deterministic.
            outputs[context["index"]] = result

        def error_callback(error: Exception, context: dict):
            # A failed batch yields an empty result of the same size.
            outputs[context["index"]] = ValidationResult.empty(size=len(batched_records[context["index"]]))

        settings = self.resource_provider.run_config
        with ConcurrentThreadExecutor(
            max_workers=self.config.validator_params.max_parallel_requests,
            column_name=self.config.name,
            result_callback=result_callback,
            error_callback=error_callback,
            shutdown_error_rate=settings.shutdown_error_rate,
            shutdown_error_window=settings.shutdown_error_window,
            disable_early_shutdown=settings.disable_early_shutdown,
        ) as executor:
            for i, batch in enumerate(batched_records):
                # `batch` is bound as a lambda parameter to avoid late-binding issues.
                executor.submit(lambda batch: self._validate_batch(validator, batch), batch, context={"index": i})

        if any(output is None for output in outputs):
            raise DataDesignerRuntimeError("Validation task failed due to an unexpected error in parallel execution")

        # Concatenate the per-batch outputs into one flat list of per-record results.
        return sum([output.data for output in outputs], [])

    def _validate_batch(self, validator: BaseValidator, batch: list[dict]) -> ValidationResult:
        """Validate one batch of records, logging (and re-raising) any failure."""
        try:
            return validator.run_validation(batch)
        except Exception as e:
            error_to_display = str(e).replace("\n", "\n ")  # add spaces to improve readability
            logger.error(f"Batch could not be validated:\n {error_to_display}")
            raise e
@@ -0,0 +1,60 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from data_designer.config.base import ConfigBase
7
+ from data_designer.config.column_configs import (
8
+ EmbeddingColumnConfig,
9
+ ExpressionColumnConfig,
10
+ LLMCodeColumnConfig,
11
+ LLMJudgeColumnConfig,
12
+ LLMStructuredColumnConfig,
13
+ LLMTextColumnConfig,
14
+ ValidationColumnConfig,
15
+ )
16
+ from data_designer.config.column_types import DataDesignerColumnType
17
+ from data_designer.engine.column_generators.generators.base import ColumnGenerator
18
+ from data_designer.engine.column_generators.generators.embedding import EmbeddingCellGenerator
19
+ from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
20
+ from data_designer.engine.column_generators.generators.llm_completion import (
21
+ LLMCodeCellGenerator,
22
+ LLMJudgeCellGenerator,
23
+ LLMStructuredCellGenerator,
24
+ LLMTextCellGenerator,
25
+ )
26
+ from data_designer.engine.column_generators.generators.samplers import SamplerColumnGenerator
27
+ from data_designer.engine.column_generators.generators.seed_dataset import SeedDatasetColumnGenerator
28
+ from data_designer.engine.column_generators.generators.validation import ValidationColumnGenerator
29
+ from data_designer.engine.dataset_builders.multi_column_configs import (
30
+ SamplerMultiColumnConfig,
31
+ SeedDatasetMultiColumnConfig,
32
+ )
33
+ from data_designer.engine.registry.base import TaskRegistry
34
+ from data_designer.plugins.plugin import PluginType
35
+ from data_designer.plugins.registry import PluginRegistry
36
+
37
+
38
class ColumnGeneratorRegistry(TaskRegistry[DataDesignerColumnType, ColumnGenerator, ConfigBase]):
    """Registry mapping each column type to its generator class and config class."""

    ...
39
+
40
+
41
def create_default_column_generator_registry(with_plugins: bool = True) -> ColumnGeneratorRegistry:
    """Build the default registry of built-in (and optionally plugin) column generators.

    Args:
        with_plugins: When True, also register generators contributed by plugins.

    Returns:
        A ColumnGeneratorRegistry populated with all built-in column generators.
    """
    registry = ColumnGeneratorRegistry()

    # Built-in registrations, kept in a table so adding a generator is one line.
    builtin_entries = (
        (DataDesignerColumnType.LLM_TEXT, LLMTextCellGenerator, LLMTextColumnConfig),
        (DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig),
        (DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig),
        (DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig),
        (DataDesignerColumnType.EMBEDDING, EmbeddingCellGenerator, EmbeddingColumnConfig),
        (DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig),
        (DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig),
        (DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig),
        (DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig),
    )
    for column_type, generator_cls, config_cls in builtin_entries:
        registry.register(column_type, generator_cls, config_cls)

    if with_plugins:
        # Plugin-provided generators are registered under their declared column type.
        for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR):
            registry.register(
                DataDesignerColumnType(plugin.name),
                plugin.impl_cls,
                plugin.config_cls,
            )

    return registry
@@ -0,0 +1,15 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from data_designer.engine.errors import DataDesignerError
7
+
8
+
9
class PromptTemplateRenderError(DataDesignerError):
    """Error raised for prompt-template rendering failures."""

    ...
10
+
11
+
12
class ExpressionTemplateRenderError(DataDesignerError):
    """Error raised for expression-template rendering failures."""

    ...
13
+
14
+
15
class SeedDatasetError(DataDesignerError):
    """Error raised for seed-dataset related failures."""

    ...
@@ -0,0 +1,43 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from data_designer.config.column_types import DataDesignerColumnType
7
+ from data_designer.config.utils.type_helpers import resolve_string_enum
8
+ from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModelRegistry
9
+ from data_designer.plugin_manager import PluginManager
10
+
11
# Module-level singleton shared by the classification helpers below.
plugin_manager = PluginManager()
12
+
13
+
14
def column_type_used_in_execution_dag(column_type: str | DataDesignerColumnType) -> bool:
    """Return True if the column type is used in the workflow execution DAG."""
    resolved = resolve_string_enum(column_type, DataDesignerColumnType)
    # Built-in DAG-participating column types, extended with any plugin-declared ones.
    dag_column_types = {
        DataDesignerColumnType.EXPRESSION,
        DataDesignerColumnType.LLM_CODE,
        DataDesignerColumnType.LLM_JUDGE,
        DataDesignerColumnType.LLM_STRUCTURED,
        DataDesignerColumnType.LLM_TEXT,
        DataDesignerColumnType.VALIDATION,
        DataDesignerColumnType.EMBEDDING,
    } | set(plugin_manager.get_plugin_column_types(DataDesignerColumnType))
    return resolved in dag_column_types
28
+
29
+
30
def column_type_is_model_generated(column_type: str | DataDesignerColumnType) -> bool:
    """Return True if the column type is a model-generated column.

    Args:
        column_type: A column type, given either as its string value or as a
            DataDesignerColumnType member.

    Returns:
        True when the (resolved) column type is produced by a model — one of
        the built-in LLM/embedding types, or a plugin generator whose
        implementation derives from ColumnGeneratorWithModelRegistry.
    """
    column_type = resolve_string_enum(column_type, DataDesignerColumnType)
    model_generated_column_types = {
        DataDesignerColumnType.LLM_TEXT,
        DataDesignerColumnType.LLM_CODE,
        DataDesignerColumnType.LLM_STRUCTURED,
        DataDesignerColumnType.LLM_JUDGE,
        DataDesignerColumnType.EMBEDDING,
    }
    for plugin in plugin_manager.get_column_generator_plugins():
        if issubclass(plugin.impl_cls, ColumnGeneratorWithModelRegistry):
            # Coerce the plugin name to the enum so membership matches the
            # resolved enum value above; a raw string would only compare equal
            # if the enum happens to subclass str. This mirrors how plugin
            # names are coerced elsewhere in this package.
            model_generated_column_types.add(DataDesignerColumnType(plugin.name))
    return column_type in model_generated_column_types
@@ -0,0 +1,58 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from enum import Enum
7
+
8
+ from pydantic import BaseModel, ConfigDict, Field, create_model
9
+
10
+ from data_designer.config.column_configs import Score
11
+
12
# Bullet line for one score option, e.g. "* 3: Mostly correct".
SCORING_FORMAT = "* {score}: {description}"
# Header plus bullet list used as the description of the dynamic score field.
SCORE_FIELD_DESCRIPTION_FORMAT = "Score Descriptions for {enum_name}:\n{scoring}"
14
+
15
+
16
class BaseJudgeResponse(BaseModel):
    """Base model for all rubrics.

    Dynamically created subclasses (see create_judge_response_model) add a
    `score` enum field; this base contributes the shared reasoning field.
    """

    # use_enum_values makes pydantic store/serialize enum fields as raw values.
    model_config = ConfigDict(use_enum_values=True)
    reasoning: str = Field(..., description="Reasoning for the assigned score.")
21
+
22
+
23
def _stringify_scoring(options: dict, enum_type: type[Enum]) -> str:
    """Render the score -> description mapping as one labelled text block."""
    bullets = (
        SCORING_FORMAT.format(score=value, description=text) for value, text in options.items()
    )
    return SCORE_FIELD_DESCRIPTION_FORMAT.format(enum_name=enum_type.__name__, scoring="\n".join(bullets))
29
+
30
+
31
def create_judge_response_model(score: Score) -> type[BaseJudgeResponse]:
    """Create a JudgeResponse data type."""
    # Build an enum over the rubric's allowed score values; member names are
    # prefixed so numeric score values produce valid identifiers.
    member_map = {f"VALUE_{value}": value for value in score.options}
    score_enum = Enum(f"{score.name}Enum", member_map)

    # The per-option descriptions become the field description of `score`.
    field_description = _stringify_scoring(score.options, enum_type=score_enum)

    return create_model(
        score.name,
        __doc__=score.description or None,
        __base__=BaseJudgeResponse,
        score=(score_enum, Field(..., description=field_description)),
    )
47
+
48
+
49
def create_judge_structured_output_model(
    judge_responses: list[type[BaseJudgeResponse]],
) -> type[BaseModel]:
    """Create a JudgeStructuredOutput class dynamically."""
    # One required field per judge-response model, keyed by the model's name.
    score_names = [response.__name__ for response in judge_responses]
    field_definitions = {
        name: (response, ...) for name, response in zip(score_names, judge_responses)
    }
    return create_model(
        "JudgeStructuredOutput",
        __doc__=f"Response schema for scores with the following names: {score_names}.",
        __base__=BaseModel,
        **field_definitions,
    )