data_designer-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. data_designer/__init__.py +15 -0
  2. data_designer/_version.py +34 -0
  3. data_designer/cli/README.md +236 -0
  4. data_designer/cli/__init__.py +6 -0
  5. data_designer/cli/commands/__init__.py +2 -0
  6. data_designer/cli/commands/list.py +130 -0
  7. data_designer/cli/commands/models.py +10 -0
  8. data_designer/cli/commands/providers.py +11 -0
  9. data_designer/cli/commands/reset.py +100 -0
  10. data_designer/cli/controllers/__init__.py +7 -0
  11. data_designer/cli/controllers/model_controller.py +246 -0
  12. data_designer/cli/controllers/provider_controller.py +317 -0
  13. data_designer/cli/forms/__init__.py +20 -0
  14. data_designer/cli/forms/builder.py +51 -0
  15. data_designer/cli/forms/field.py +180 -0
  16. data_designer/cli/forms/form.py +59 -0
  17. data_designer/cli/forms/model_builder.py +125 -0
  18. data_designer/cli/forms/provider_builder.py +76 -0
  19. data_designer/cli/main.py +44 -0
  20. data_designer/cli/repositories/__init__.py +8 -0
  21. data_designer/cli/repositories/base.py +39 -0
  22. data_designer/cli/repositories/model_repository.py +42 -0
  23. data_designer/cli/repositories/provider_repository.py +43 -0
  24. data_designer/cli/services/__init__.py +7 -0
  25. data_designer/cli/services/model_service.py +116 -0
  26. data_designer/cli/services/provider_service.py +111 -0
  27. data_designer/cli/ui.py +448 -0
  28. data_designer/cli/utils.py +47 -0
  29. data_designer/config/__init__.py +2 -0
  30. data_designer/config/analysis/column_profilers.py +89 -0
  31. data_designer/config/analysis/column_statistics.py +274 -0
  32. data_designer/config/analysis/dataset_profiler.py +60 -0
  33. data_designer/config/analysis/utils/errors.py +8 -0
  34. data_designer/config/analysis/utils/reporting.py +188 -0
  35. data_designer/config/base.py +68 -0
  36. data_designer/config/column_configs.py +354 -0
  37. data_designer/config/column_types.py +168 -0
  38. data_designer/config/config_builder.py +660 -0
  39. data_designer/config/data_designer_config.py +40 -0
  40. data_designer/config/dataset_builders.py +11 -0
  41. data_designer/config/datastore.py +151 -0
  42. data_designer/config/default_model_settings.py +123 -0
  43. data_designer/config/errors.py +19 -0
  44. data_designer/config/interface.py +54 -0
  45. data_designer/config/models.py +231 -0
  46. data_designer/config/preview_results.py +32 -0
  47. data_designer/config/processors.py +41 -0
  48. data_designer/config/sampler_constraints.py +51 -0
  49. data_designer/config/sampler_params.py +604 -0
  50. data_designer/config/seed.py +145 -0
  51. data_designer/config/utils/code_lang.py +83 -0
  52. data_designer/config/utils/constants.py +313 -0
  53. data_designer/config/utils/errors.py +19 -0
  54. data_designer/config/utils/info.py +88 -0
  55. data_designer/config/utils/io_helpers.py +273 -0
  56. data_designer/config/utils/misc.py +81 -0
  57. data_designer/config/utils/numerical_helpers.py +28 -0
  58. data_designer/config/utils/type_helpers.py +100 -0
  59. data_designer/config/utils/validation.py +336 -0
  60. data_designer/config/utils/visualization.py +427 -0
  61. data_designer/config/validator_params.py +96 -0
  62. data_designer/engine/__init__.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +55 -0
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
  65. data_designer/engine/analysis/column_profilers/registry.py +20 -0
  66. data_designer/engine/analysis/column_statistics.py +142 -0
  67. data_designer/engine/analysis/dataset_profiler.py +125 -0
  68. data_designer/engine/analysis/errors.py +7 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
  70. data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
  71. data_designer/engine/column_generators/__init__.py +2 -0
  72. data_designer/engine/column_generators/generators/__init__.py +2 -0
  73. data_designer/engine/column_generators/generators/base.py +61 -0
  74. data_designer/engine/column_generators/generators/expression.py +63 -0
  75. data_designer/engine/column_generators/generators/llm_generators.py +172 -0
  76. data_designer/engine/column_generators/generators/samplers.py +75 -0
  77. data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
  78. data_designer/engine/column_generators/generators/validation.py +147 -0
  79. data_designer/engine/column_generators/registry.py +56 -0
  80. data_designer/engine/column_generators/utils/errors.py +13 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
  83. data_designer/engine/configurable_task.py +82 -0
  84. data_designer/engine/dataset_builders/artifact_storage.py +181 -0
  85. data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
  86. data_designer/engine/dataset_builders/errors.py +13 -0
  87. data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
  88. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +56 -0
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
  93. data_designer/engine/dataset_builders/utils/errors.py +13 -0
  94. data_designer/engine/errors.py +49 -0
  95. data_designer/engine/model_provider.py +75 -0
  96. data_designer/engine/models/__init__.py +2 -0
  97. data_designer/engine/models/errors.py +308 -0
  98. data_designer/engine/models/facade.py +225 -0
  99. data_designer/engine/models/litellm_overrides.py +162 -0
  100. data_designer/engine/models/parsers/__init__.py +2 -0
  101. data_designer/engine/models/parsers/errors.py +34 -0
  102. data_designer/engine/models/parsers/parser.py +236 -0
  103. data_designer/engine/models/parsers/postprocessors.py +93 -0
  104. data_designer/engine/models/parsers/tag_parsers.py +60 -0
  105. data_designer/engine/models/parsers/types.py +82 -0
  106. data_designer/engine/models/recipes/base.py +79 -0
  107. data_designer/engine/models/recipes/response_recipes.py +291 -0
  108. data_designer/engine/models/registry.py +118 -0
  109. data_designer/engine/models/usage.py +75 -0
  110. data_designer/engine/models/utils.py +38 -0
  111. data_designer/engine/processing/ginja/__init__.py +2 -0
  112. data_designer/engine/processing/ginja/ast.py +64 -0
  113. data_designer/engine/processing/ginja/environment.py +461 -0
  114. data_designer/engine/processing/ginja/exceptions.py +54 -0
  115. data_designer/engine/processing/ginja/record.py +30 -0
  116. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  117. data_designer/engine/processing/gsonschema/exceptions.py +8 -0
  118. data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
  119. data_designer/engine/processing/gsonschema/types.py +8 -0
  120. data_designer/engine/processing/gsonschema/validators.py +143 -0
  121. data_designer/engine/processing/processors/base.py +15 -0
  122. data_designer/engine/processing/processors/drop_columns.py +46 -0
  123. data_designer/engine/processing/processors/registry.py +20 -0
  124. data_designer/engine/processing/utils.py +120 -0
  125. data_designer/engine/registry/base.py +97 -0
  126. data_designer/engine/registry/data_designer_registry.py +37 -0
  127. data_designer/engine/registry/errors.py +10 -0
  128. data_designer/engine/resources/managed_dataset_generator.py +35 -0
  129. data_designer/engine/resources/managed_dataset_repository.py +194 -0
  130. data_designer/engine/resources/managed_storage.py +63 -0
  131. data_designer/engine/resources/resource_provider.py +46 -0
  132. data_designer/engine/resources/seed_dataset_data_store.py +66 -0
  133. data_designer/engine/sampling_gen/column.py +89 -0
  134. data_designer/engine/sampling_gen/constraints.py +95 -0
  135. data_designer/engine/sampling_gen/data_sources/base.py +214 -0
  136. data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
  137. data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
  138. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  139. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  140. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
  141. data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
  142. data_designer/engine/sampling_gen/entities/errors.py +8 -0
  143. data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
  144. data_designer/engine/sampling_gen/entities/person.py +142 -0
  145. data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
  146. data_designer/engine/sampling_gen/errors.py +24 -0
  147. data_designer/engine/sampling_gen/generator.py +121 -0
  148. data_designer/engine/sampling_gen/jinja_utils.py +60 -0
  149. data_designer/engine/sampling_gen/people_gen.py +203 -0
  150. data_designer/engine/sampling_gen/person_constants.py +54 -0
  151. data_designer/engine/sampling_gen/schema.py +143 -0
  152. data_designer/engine/sampling_gen/schema_builder.py +59 -0
  153. data_designer/engine/sampling_gen/utils.py +40 -0
  154. data_designer/engine/secret_resolver.py +80 -0
  155. data_designer/engine/validators/__init__.py +17 -0
  156. data_designer/engine/validators/base.py +36 -0
  157. data_designer/engine/validators/local_callable.py +34 -0
  158. data_designer/engine/validators/python.py +245 -0
  159. data_designer/engine/validators/remote.py +83 -0
  160. data_designer/engine/validators/sql.py +60 -0
  161. data_designer/errors.py +5 -0
  162. data_designer/essentials/__init__.py +137 -0
  163. data_designer/interface/__init__.py +2 -0
  164. data_designer/interface/data_designer.py +351 -0
  165. data_designer/interface/errors.py +16 -0
  166. data_designer/interface/results.py +55 -0
  167. data_designer/logging.py +161 -0
  168. data_designer/plugin_manager.py +83 -0
  169. data_designer/plugins/__init__.py +6 -0
  170. data_designer/plugins/errors.py +10 -0
  171. data_designer/plugins/plugin.py +69 -0
  172. data_designer/plugins/registry.py +86 -0
  173. data_designer-0.1.0.dist-info/METADATA +173 -0
  174. data_designer-0.1.0.dist-info/RECORD +177 -0
  175. data_designer-0.1.0.dist-info/WHEEL +4 -0
  176. data_designer-0.1.0.dist-info/entry_points.txt +2 -0
  177. data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
data_designer/engine/column_generators/generators/validation.py
@@ -0,0 +1,147 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+
+import pandas as pd
+
+from data_designer.config.column_configs import ValidationColumnConfig
+from data_designer.config.errors import InvalidConfigError
+from data_designer.config.utils.code_lang import SQL_DIALECTS, CodeLang
+from data_designer.config.validator_params import (
+    ValidatorParamsT,
+    ValidatorType,
+)
+from data_designer.engine.column_generators.generators.base import (
+    ColumnGenerator,
+    GenerationStrategy,
+    GeneratorMetadata,
+)
+from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor
+from data_designer.engine.errors import DataDesignerRuntimeError
+from data_designer.engine.validators import (
+    BaseValidator,
+    LocalCallableValidator,
+    PythonValidator,
+    RemoteValidator,
+    SQLValidator,
+    ValidationResult,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def get_validator_from_params(validator_type: ValidatorType, validator_params: ValidatorParamsT) -> BaseValidator:
+    if validator_type == ValidatorType.CODE:
+        if validator_params.code_lang == CodeLang.PYTHON:
+            return PythonValidator(validator_params)
+        elif validator_params.code_lang in SQL_DIALECTS:
+            return SQLValidator(validator_params)
+    elif validator_type == ValidatorType.REMOTE:
+        return RemoteValidator(validator_params)
+    else:
+        return LocalCallableValidator(validator_params)
+
+
+class ValidationColumnGenerator(ColumnGenerator[ValidationColumnConfig]):
+    @staticmethod
+    def metadata() -> GeneratorMetadata:
+        return GeneratorMetadata(
+            name="validate",
+            description="Validate data.",
+            generation_strategy=GenerationStrategy.FULL_COLUMN,
+            required_resources=None,
+        )
+
+    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
+        logger.info(f"🔍 Validating column {self.config.name!r} with {len(data)} records")
+        logger.info(f" |-- target columns: {self.config.target_columns}")
+        logger.info(f" |-- validator type: {self.config.validator_type}")
+        logger.info(f" |-- validator params: {self.config.validator_params}")
+        logger.info(f" |-- batch size: {self.config.batch_size}")
+
+        validator = get_validator_from_params(self.config.validator_type, self.config.validator_params)
+
+        # Check if the target columns are present in the dataset
+        missing_columns = set(self.config.target_columns) - set(data.columns)
+        if missing_columns:
+            raise InvalidConfigError(
+                f"Target columns {missing_columns} defined in validation column {self.config.name!r} are missing from the dataset"
+            )
+
+        # Check whether to pass single columns or multiple columns to the validator
+        validate_columns_separately = False
+        if self.config.validator_type == ValidatorType.CODE and len(self.config.target_columns) > 1:
+            # Code validators expect single-column input, so we validate each column separately
+            validate_columns_separately = True
+
+            columns_to_validate = [[col] for col in self.config.target_columns]
+        else:
+            columns_to_validate = [self.config.target_columns]
+
+        outputs_as_dicts = None
+        for cols in columns_to_validate:
+            # Filter the dataset to only include the target columns, and convert to a list of dictionaries
+            records = data[cols].to_dict(orient="records")
+
+            batched_records = [
+                records[batch_start : batch_start + self.config.batch_size]
+                for batch_start in range(0, len(records), self.config.batch_size)
+            ]
+
+            # Run validation in parallel or sequentially, depending on the validator type and parameters
+            if (
+                self.config.validator_type == ValidatorType.REMOTE
+                and self.config.validator_params.max_parallel_requests > 1
+            ):
+                concatenated_outputs = self._validate_in_parallel(validator, batched_records)
+            else:
+                concatenated_outputs = []
+                for batch in batched_records:
+                    concatenated_outputs.extend(self._validate_batch(validator, batch))
+
+            if validate_columns_separately:
+                if outputs_as_dicts is None:
+                    outputs_as_dicts = [{cols[0]: output.model_dump(mode="json")} for output in concatenated_outputs]
+                else:
+                    for dict_output, output in zip(outputs_as_dicts, concatenated_outputs):
+                        dict_output[cols[0]] = output.model_dump(mode="json")
+            else:
+                outputs_as_dicts = [output.model_dump(mode="json") for output in concatenated_outputs]
+
+        validation_results = pd.DataFrame({self.config.name: outputs_as_dicts})
+        return pd.concat([data, validation_results], axis=1)
+
+    def _validate_in_parallel(self, validator: BaseValidator, batched_records: list[list[dict]]) -> list:
+        """Run validation in parallel."""
+
+        outputs = [None] * len(batched_records)
+
+        def result_callback(result: ValidationResult, context: dict):
+            outputs[context["index"]] = result
+
+        def error_callback(error: Exception, context: dict):
+            outputs[context["index"]] = ValidationResult.empty(size=len(batched_records[context["index"]]))
+
+        with ConcurrentThreadExecutor(
+            max_workers=self.config.validator_params.max_parallel_requests,
+            column_name=self.config.name,
+            result_callback=result_callback,
+            error_callback=error_callback,
+        ) as executor:
+            for i, batch in enumerate(batched_records):
+                executor.submit(lambda batch: self._validate_batch(validator, batch), batch, context={"index": i})
+
+        if any(output is None for output in outputs):
+            raise DataDesignerRuntimeError("Validation task failed due to an unexpected error in parallel execution")
+
+        # Flatten the per-batch outputs into a single list of results
+        return sum([output.data for output in outputs], [])
+
+    def _validate_batch(self, validator: BaseValidator, batch: list[dict]) -> ValidationResult:
+        try:
+            return validator.run_validation(batch)
+        except Exception as e:
+            error_to_display = str(e).replace("\n", "\n ")  # add spaces to improve readability
+            logger.error(f"Batch could not be validated:\n {error_to_display}")
+            raise e
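The batch/fan-out pattern used by `generate` and `_validate_in_parallel` above is easier to see in isolation. Below is a minimal sketch of the same idea using only the standard library: `concurrent.futures.ThreadPoolExecutor` stands in for the package's `ConcurrentThreadExecutor`, and `fake_validate` is an invented stand-in for `BaseValidator.run_validation`.

```python
from concurrent.futures import ThreadPoolExecutor


def fake_validate(batch: list[dict]) -> list[bool]:
    # Stand-in for a validator: one result per record in the batch.
    return [bool(record) for record in batch]


records = [{"code": f"SELECT {i};"} for i in range(10)]
batch_size = 3

# Same slicing idiom as in generate(): fixed-size batches, last one ragged.
batched_records = [records[i : i + batch_size] for i in range(0, len(records), batch_size)]

# Fan out one task per batch; keeping the futures in a list preserves
# submission order, mirroring the context={"index": i} bookkeeping above.
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(fake_validate, batch) for batch in batched_records]
    outputs = [future.result() for future in futures]

# Flatten per-batch outputs back into one result per record.
flat_results = [result for batch_output in outputs for result in batch_output]
assert len(flat_results) == len(records)
```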
data_designer/engine/column_generators/registry.py
@@ -0,0 +1,56 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from data_designer.config.base import ConfigBase
+from data_designer.config.column_configs import (
+    ExpressionColumnConfig,
+    LLMCodeColumnConfig,
+    LLMJudgeColumnConfig,
+    LLMStructuredColumnConfig,
+    LLMTextColumnConfig,
+    ValidationColumnConfig,
+)
+from data_designer.config.column_types import DataDesignerColumnType
+from data_designer.engine.column_generators.generators.base import ColumnGenerator
+from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
+from data_designer.engine.column_generators.generators.llm_generators import (
+    LLMCodeCellGenerator,
+    LLMJudgeCellGenerator,
+    LLMStructuredCellGenerator,
+    LLMTextCellGenerator,
+)
+from data_designer.engine.column_generators.generators.samplers import SamplerColumnGenerator
+from data_designer.engine.column_generators.generators.seed_dataset import SeedDatasetColumnGenerator
+from data_designer.engine.column_generators.generators.validation import ValidationColumnGenerator
+from data_designer.engine.dataset_builders.multi_column_configs import (
+    SamplerMultiColumnConfig,
+    SeedDatasetMultiColumnConfig,
+)
+from data_designer.engine.registry.base import TaskRegistry
+from data_designer.plugins.plugin import PluginType
+from data_designer.plugins.registry import PluginRegistry
+
+
+class ColumnGeneratorRegistry(TaskRegistry[DataDesignerColumnType, ColumnGenerator, ConfigBase]): ...
+
+
+def create_default_column_generator_registry(with_plugins: bool = True) -> ColumnGeneratorRegistry:
+    registry = ColumnGeneratorRegistry()
+    registry.register(DataDesignerColumnType.LLM_TEXT, LLMTextCellGenerator, LLMTextColumnConfig)
+    registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig)
+    registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig)
+    registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig)
+    registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig)
+    registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig)
+    registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig)
+    registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig)
+
+    if with_plugins:
+        for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR):
+            registry.register(
+                DataDesignerColumnType(plugin.name),
+                plugin.task_cls,
+                plugin.config_cls,
+            )
+
+    return registry
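`TaskRegistry` itself lives elsewhere in the wheel (`data_designer/engine/registry/base.py`, not shown in this excerpt), but the registration pattern above is a keyed lookup from column type to a (generator class, config class) pair. A self-contained sketch of that idea, with all names invented for illustration:

```python
from dataclasses import dataclass
from enum import Enum


class ColumnType(Enum):
    LLM_TEXT = "llm-text"
    VALIDATION = "validation"


@dataclass
class RegistryEntry:
    task_cls: type
    config_cls: type


class MiniRegistry:
    """Toy analogue of TaskRegistry: column type -> (generator, config)."""

    def __init__(self) -> None:
        self._entries: dict[ColumnType, RegistryEntry] = {}

    def register(self, key: ColumnType, task_cls: type, config_cls: type) -> None:
        self._entries[key] = RegistryEntry(task_cls, config_cls)

    def get(self, key: ColumnType) -> RegistryEntry:
        if key not in self._entries:
            raise KeyError(f"No generator registered for column type {key!r}")
        return self._entries[key]


registry = MiniRegistry()
registry.register(ColumnType.VALIDATION, object, object)  # stand-in classes
print(registry.get(ColumnType.VALIDATION))
```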
data_designer/engine/column_generators/utils/errors.py
@@ -0,0 +1,13 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from data_designer.engine.errors import DataDesignerError
+
+
+class PromptTemplateRenderError(DataDesignerError): ...
+
+
+class ExpressionTemplateRenderError(DataDesignerError): ...
+
+
+class SeedDatasetError(DataDesignerError): ...
data_designer/engine/column_generators/utils/judge_score_factory.py
@@ -0,0 +1,57 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from enum import Enum
+from typing import Type
+
+from pydantic import BaseModel, ConfigDict, Field, create_model
+
+from data_designer.config.column_configs import Score
+
+SCORING_FORMAT = "* {score}: {description}"
+SCORE_FIELD_DESCRIPTION_FORMAT = "Score Descriptions for {enum_name}:\n{scoring}"
+
+
+class BaseJudgeResponse(BaseModel):
+    """Base model for all rubrics."""
+
+    model_config = ConfigDict(use_enum_values=True)
+    reasoning: str = Field(..., description="Reasoning for the assigned score.")
+
+
+def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str:
+    """Convert score descriptions into a single text block."""
+    list_block = "\n".join(
+        [SCORING_FORMAT.format(score=score, description=description) for score, description in options.items()]
+    )
+    return SCORE_FIELD_DESCRIPTION_FORMAT.format(enum_name=enum_type.__name__, scoring=list_block)
+
+
+def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]:
+    """Create a JudgeResponse data type."""
+    enum_members = {}
+    for option in score.options.keys():
+        member_name = f"VALUE_{option}"
+        enum_members[member_name] = option
+
+    DynamicScaleEnum = Enum(f"{score.name}Enum", enum_members)
+    options = _stringify_scoring(score.options, enum_type=DynamicScaleEnum)
+
+    return create_model(
+        score.name,
+        __doc__=score.description if score.description else None,
+        __base__=BaseJudgeResponse,
+        score=(DynamicScaleEnum, Field(..., description=options)),
+    )
+
+
+def create_judge_structured_output_model(
+    judge_responses: list[Type[BaseJudgeResponse]],
+) -> Type[BaseModel]:
+    """Create a JudgeStructuredOutput class dynamically."""
+    return create_model(
+        "JudgeStructuredOutput",
+        __doc__=f"Response schema for scores with the following names: {[response.__name__ for response in judge_responses]}.",
+        __base__=BaseModel,
+        **{response.__name__.lower(): (response, ...) for response in judge_responses},
+    )
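To make the dynamic model construction concrete, here is a standalone sketch of the same `Enum` + `create_model` pattern for a hypothetical three-point rubric. It uses plain pydantic; the `Accuracy` rubric and its options are invented, since the package's `Score` config class is not shown in this excerpt.

```python
from enum import Enum

from pydantic import BaseModel, ConfigDict, Field, create_model


class BaseJudgeResponse(BaseModel):
    model_config = ConfigDict(use_enum_values=True)
    reasoning: str = Field(..., description="Reasoning for the assigned score.")


options = {1: "Incorrect", 3: "Partially correct", 5: "Fully correct"}

# Mirrors create_judge_response_model: members named VALUE_<option>.
AccuracyEnum = Enum("AccuracyEnum", {f"VALUE_{k}": k for k in options})

score_descriptions = "Score Descriptions for AccuracyEnum:\n" + "\n".join(
    f"* {score}: {description}" for score, description in options.items()
)

Accuracy = create_model(
    "Accuracy",
    __base__=BaseJudgeResponse,
    score=(AccuracyEnum, Field(..., description=score_descriptions)),
)

response = Accuracy(reasoning="Matches the reference answer.", score=5)
print(response.model_dump())  # use_enum_values=True serializes the raw value 5
```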
data_designer/engine/column_generators/utils/prompt_renderer.py
@@ -0,0 +1,98 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import logging
+
+from data_designer.config.column_configs import SingleColumnConfig
+from data_designer.config.column_types import DataDesignerColumnType
+from data_designer.config.models import ModelConfig
+from data_designer.config.utils.code_lang import CodeLang
+from data_designer.config.utils.misc import get_prompt_template_keywords
+from data_designer.config.utils.type_helpers import StrEnum
+from data_designer.engine.column_generators.utils.errors import PromptTemplateRenderError
+from data_designer.engine.column_generators.utils.judge_score_factory import (
+    create_judge_response_model,
+    create_judge_structured_output_model,
+)
+from data_designer.engine.models.recipes.base import ResponseRecipe
+from data_designer.engine.models.recipes.response_recipes import (
+    CodeResponseRecipe,
+    PydanticResponseRecipe,
+    StructuredResponseRecipe,
+    TextResponseRecipe,
+)
+from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering
+from data_designer.engine.processing.ginja.exceptions import UserTemplateError, UserTemplateUnsupportedFiltersError
+
+logger = logging.getLogger(__name__)
+
+
+class PromptType(StrEnum):
+    SYSTEM_PROMPT = "system_prompt"
+    USER_PROMPT = "user_prompt"
+
+
+class RecordBasedPromptRenderer(WithJinja2UserTemplateRendering):
+    def __init__(self, response_recipe: ResponseRecipe, *, error_message_context: dict[str, str] | None = None):
+        self.response_recipe = response_recipe
+        self._error_message_context = error_message_context
+
+    def render(self, *, prompt_template: str | None, record: dict, prompt_type: PromptType) -> str | None:
+        self._prepare_environment(prompt_template=prompt_template, record=record, prompt_type=prompt_type)
+        rendered_prompt = self.render_multi_template(prompt_type, record) if prompt_template else ""
+        recipe_applicator = (
+            self.response_recipe.apply_recipe_to_user_prompt
+            if prompt_type == PromptType.USER_PROMPT
+            else self.response_recipe.apply_recipe_to_system_prompt
+        )
+        return recipe_applicator(rendered_prompt)
+
+    def _prepare_environment(self, *, prompt_template: str | None, record: dict, prompt_type: PromptType) -> None:
+        try:
+            self.prepare_jinja2_multi_template_renderer(
+                template_name=prompt_type.value,
+                prompt_template=prompt_template,
+                dataset_variables=list(record.keys()),
+            )
+        except (UserTemplateUnsupportedFiltersError, UserTemplateError) as exc:
+            template_variables = get_prompt_template_keywords(prompt_template)
+            missing_columns = list(set(template_variables) - set(record.keys()))
+
+            error_msg = (
+                f"There was an error preparing the {prompt_type.value.replace('_', ' ')} "
+                "template. Please double check that the template is valid Jinja2 syntax, that all "
+                "referenced variables are defined, and that any filters you are using are supported."
+            )
+            if len(missing_columns) > 0:
+                error_msg += f"\nThe following columns are missing: {missing_columns}!"
+            if self._error_message_context is not None:
+                error_msg += f"\n{json.dumps(self._error_message_context, indent=2)}"
+            logger.error(f"🛑 {error_msg}")
+            raise PromptTemplateRenderError(f"{exc!s} {error_msg}")
+
+
+def create_response_recipe(
+    column_config: SingleColumnConfig, model_config: ModelConfig | None = None
+) -> ResponseRecipe:
+    if model_config and column_config.model_alias != model_config.alias:
+        raise ValueError(
+            f"Column config model alias {column_config.model_alias} does not match model config alias {model_config.alias}"
+        )
+    if column_config.column_type == DataDesignerColumnType.LLM_TEXT:
+        return TextResponseRecipe()
+    if column_config.column_type == DataDesignerColumnType.LLM_CODE:
+        return CodeResponseRecipe(
+            syntax=CodeLang.parse_lang(column_config.code_lang),
+        )
+    if column_config.column_type == DataDesignerColumnType.LLM_STRUCTURED:
+        return StructuredResponseRecipe(
+            json_schema=column_config.output_format,
+        )
+    if column_config.column_type == DataDesignerColumnType.LLM_JUDGE:
+        return PydanticResponseRecipe(
+            data_type=create_judge_structured_output_model(
+                [create_judge_response_model(s) for s in column_config.scores]
+            ),
+        )
+    raise ValueError(f"No response recipe found for column type: {column_config.column_type}")
data_designer/engine/configurable_task.py
@@ -0,0 +1,82 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Generic, Type, TypeVar, get_origin
+
+import pandas as pd
+
+from data_designer.config.base import ConfigBase
+from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
+from data_designer.engine.resources.resource_provider import ResourceProvider, ResourceType
+
+DataT = TypeVar("DataT", dict, pd.DataFrame)
+TaskConfigT = TypeVar("TaskConfigT", bound=ConfigBase)
+
+
+class ConfigurableTaskMetadata(ConfigBase):
+    name: str
+    description: str
+    required_resources: list[ResourceType] | None
+
+
+class ConfigurableTask(ABC, Generic[TaskConfigT]):
+    def __init__(self, config: TaskConfigT, *, resource_provider: ResourceProvider | None):
+        self._config = self.get_config_type().model_validate(config)
+        self._resource_provider = resource_provider
+        self._validate_resources()
+        self._validate()
+        self._initialize()
+
+    @classmethod
+    def get_config_type(cls) -> Type[TaskConfigT]:
+        for base in cls.__orig_bases__:
+            if hasattr(base, "__args__") and len(base.__args__) == 1:
+                arg = base.__args__[0]
+                origin = get_origin(arg) or arg
+                if isinstance(origin, type) and issubclass(origin, ConfigBase):
+                    return base.__args__[0]
+        raise TypeError(
+            f"Could not determine config type for `{cls.__name__}`. Please ensure that the "
+            "`ConfigurableTask` is defined with a generic type argument, where the type argument "
+            "is a subclass of `ConfigBase`."
+        )
+
+    @property
+    def artifact_path(self) -> Path:
+        return self.artifact_storage.artifact_path
+
+    @property
+    def artifact_storage(self) -> ArtifactStorage:
+        return self.resource_provider.artifact_storage
+
+    @property
+    def base_dataset_path(self) -> Path:
+        return self.artifact_storage.base_dataset_path
+
+    @property
+    def config(self) -> TaskConfigT:
+        return self._config
+
+    @property
+    def resource_provider(self) -> ResourceProvider:
+        if self._resource_provider is None:
+            raise ValueError(f"No resource provider provided for the `{self.metadata().name}` task.")
+        return self._resource_provider
+
+    @staticmethod
+    @abstractmethod
+    def metadata() -> ConfigurableTaskMetadata: ...
+
+    def _initialize(self) -> None:
+        """An internal method for custom initialization logic, which will be called in the constructor."""
+
+    def _validate(self) -> None:
+        """An internal method for custom validation logic, which will be called in the constructor."""
+
+    def _validate_resources(self) -> None:
+        for resource in self.metadata().required_resources or []:
+            if resource is not None:
+                if getattr(self.resource_provider, ResourceType(resource).value) is None:
+                    raise ValueError(f"Resource {resource} is required for the `{self.metadata().name}` task.")
data_designer/engine/dataset_builders/artifact_storage.py
@@ -0,0 +1,181 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import logging
+import shutil
+from pathlib import Path
+from typing import Union
+
+import pandas as pd
+from pydantic import BaseModel, field_validator, model_validator
+
+from data_designer.config.utils.io_helpers import read_parquet_dataset
+from data_designer.config.utils.type_helpers import StrEnum, resolve_string_enum
+from data_designer.engine.dataset_builders.errors import ArtifactStorageError
+
+logger = logging.getLogger(__name__)
+
+BATCH_FILE_NAME_FORMAT = "batch_{batch_number:05d}.parquet"
+
+
+class BatchStage(StrEnum):
+    PARTIAL_RESULT = "partial_results_path"
+    FINAL_RESULT = "final_dataset_path"
+    DROPPED_COLUMNS = "dropped_columns_dataset_path"
+
+
+class ArtifactStorage(BaseModel):
+    artifact_path: Path | str
+    dataset_name: str = "dataset"
+    final_dataset_folder_name: str = "parquet-files"
+    partial_results_folder_name: str = "tmp-partial-parquet-files"
+    dropped_columns_folder_name: str = "dropped-columns-parquet-files"
+
+    @property
+    def artifact_path_exists(self) -> bool:
+        return self.artifact_path.exists()
+
+    @property
+    def base_dataset_path(self) -> Path:
+        return self.artifact_path / self.dataset_name
+
+    @property
+    def dropped_columns_dataset_path(self) -> Path:
+        return self.base_dataset_path / self.dropped_columns_folder_name
+
+    @property
+    def final_dataset_path(self) -> Path:
+        return self.base_dataset_path / self.final_dataset_folder_name
+
+    @property
+    def metadata_file_path(self) -> Path:
+        return self.base_dataset_path / "metadata.json"
+
+    @property
+    def partial_results_path(self) -> Path:
+        return self.base_dataset_path / self.partial_results_folder_name
+
+    @field_validator("artifact_path")
+    def validate_artifact_path(cls, v: Union[Path, str]) -> Path:
+        v = Path(v)
+        if not v.is_dir():
+            raise ArtifactStorageError("Artifact path must exist and be a directory")
+        return v
+
+    @model_validator(mode="after")
+    def validate_folder_names(self):
+        folder_names = [
+            self.dataset_name,
+            self.final_dataset_folder_name,
+            self.partial_results_folder_name,
+            self.dropped_columns_folder_name,
+        ]
+
+        for name in folder_names:
+            if len(name) == 0:
+                raise ArtifactStorageError("🛑 Directory names must be non-empty strings.")
+
+        if len(set(folder_names)) != len(folder_names):
+            raise ArtifactStorageError("🛑 Folder names must be unique (no collisions allowed).")
+
+        invalid_chars = {"<", ">", ":", '"', "/", "\\", "|", "?", "*"}
+        for name in folder_names:
+            if any(char in invalid_chars for char in name):
+                raise ArtifactStorageError(f"🛑 Directory name '{name}' contains invalid characters.")
+
+        return self
+
+    @staticmethod
+    def mkdir_if_needed(path: Path | str) -> Path:
+        """Create the directory if it does not exist."""
+        path = Path(path)
+        if not path.exists():
+            logger.debug(f"📁 Creating directory: {path}")
+            path.mkdir(parents=True, exist_ok=True)
+        return path
+
+    @staticmethod
+    def read_parquet_files(path: Path) -> pd.DataFrame:
+        return read_parquet_dataset(path)
+
+    def create_batch_file_path(
+        self,
+        batch_number: int,
+        batch_stage: BatchStage,
+    ) -> Path:
+        if batch_number < 0:
+            raise ArtifactStorageError("🛑 Batch number must be non-negative.")
+        return self._get_stage_path(batch_stage) / BATCH_FILE_NAME_FORMAT.format(batch_number=batch_number)
+
+    def load_dataset(self, batch_stage: BatchStage = BatchStage.FINAL_RESULT) -> pd.DataFrame:
+        return read_parquet_dataset(self._get_stage_path(batch_stage))
+
+    def load_dataset_with_dropped_columns(self) -> pd.DataFrame:
+        # The pyarrow backend has better support for nested data types.
+        df = self.load_dataset()
+        if (
+            self.dropped_columns_dataset_path.exists()
+            and self.create_batch_file_path(0, BatchStage.DROPPED_COLUMNS).is_file()
+        ):
+            logger.debug("Concatenating dropped columns to the final dataset.")
+            df_dropped = self.load_dataset(batch_stage=BatchStage.DROPPED_COLUMNS)
+            if len(df_dropped) != len(df):
+                raise ArtifactStorageError(
+                    "🛑 The dropped-columns dataset has a different number of rows than the main dataset. "
+                    "Something unexpected must have happened to the dataset builder's artifacts."
+                )
+            # To ensure indexes are aligned and avoid silent misalignment (which would introduce NaNs),
+            # check that the indexes are identical before concatenation.
+            if not df.index.equals(df_dropped.index):
+                raise ArtifactStorageError(
+                    "🛑 The indexes of the main and dropped-columns DataFrames are not aligned. "
+                    "Something unexpected must have happened to the dataset builder's artifacts."
+                )
+            df = pd.concat([df, df_dropped], axis=1)
+        return df
+
+    def move_partial_result_to_final_file_path(self, batch_number: int) -> Path:
+        partial_result_path = self.create_batch_file_path(batch_number, batch_stage=BatchStage.PARTIAL_RESULT)
+        if not partial_result_path.exists():
+            raise ArtifactStorageError("🛑 Partial result file not found.")
+        self.mkdir_if_needed(self._get_stage_path(BatchStage.FINAL_RESULT))
+        final_file_path = self.create_batch_file_path(batch_number, batch_stage=BatchStage.FINAL_RESULT)
+        shutil.move(partial_result_path, final_file_path)
+        return final_file_path
+
+    def write_configs(self, json_file_name: str, configs: list[BaseModel]) -> Path:
+        self.mkdir_if_needed(self.base_dataset_path)
+        with open(self.base_dataset_path / json_file_name, "w") as file:
+            json.dump([c.model_dump(mode="json") for c in configs], file, indent=4)
+        return self.base_dataset_path / json_file_name
+
+    def write_batch_to_parquet_file(
+        self,
+        batch_number: int,
+        dataframe: pd.DataFrame,
+        batch_stage: BatchStage,
+    ) -> Path:
+        file_path = self.create_batch_file_path(batch_number, batch_stage=batch_stage)
+        self.write_parquet_file(file_path.name, dataframe, batch_stage)
+        return file_path
+
+    def write_parquet_file(
+        self,
+        parquet_file_name: str,
+        dataframe: pd.DataFrame,
+        batch_stage: BatchStage,
+    ) -> Path:
+        self.mkdir_if_needed(self._get_stage_path(batch_stage))
+        file_path = self._get_stage_path(batch_stage) / parquet_file_name
+        dataframe.to_parquet(file_path, index=False)
+        return file_path
+
+    def write_metadata(self, metadata: dict) -> Path:
+        self.mkdir_if_needed(self.base_dataset_path)
+        with open(self.metadata_file_path, "w") as file:
+            json.dump(metadata, file)
+        return self.metadata_file_path
+
+    def _get_stage_path(self, stage: BatchStage) -> Path:
+        return getattr(self, resolve_string_enum(stage, BatchStage).value)
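As a quick illustration of the on-disk layout these helpers imply (paths assume the default folder names above; the sketch only exercises the file-name format):

```python
# Batch files are zero-padded to five digits, so lexicographic order
# matches batch order when globbing the directory:
BATCH_FILE_NAME_FORMAT = "batch_{batch_number:05d}.parquet"
print(BATCH_FILE_NAME_FORMAT.format(batch_number=3))    # batch_00003.parquet
print(BATCH_FILE_NAME_FORMAT.format(batch_number=112))  # batch_00112.parquet

# With the defaults, an artifact directory would look roughly like:
#   <artifact_path>/dataset/tmp-partial-parquet-files/batch_00000.parquet   (in progress)
#   <artifact_path>/dataset/parquet-files/batch_00000.parquet               (finalized)
#   <artifact_path>/dataset/dropped-columns-parquet-files/batch_00000.parquet
#   <artifact_path>/dataset/metadata.json
```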