data-designer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. data_designer/__init__.py +15 -0
  2. data_designer/_version.py +34 -0
  3. data_designer/cli/README.md +236 -0
  4. data_designer/cli/__init__.py +6 -0
  5. data_designer/cli/commands/__init__.py +2 -0
  6. data_designer/cli/commands/list.py +130 -0
  7. data_designer/cli/commands/models.py +10 -0
  8. data_designer/cli/commands/providers.py +11 -0
  9. data_designer/cli/commands/reset.py +100 -0
  10. data_designer/cli/controllers/__init__.py +7 -0
  11. data_designer/cli/controllers/model_controller.py +246 -0
  12. data_designer/cli/controllers/provider_controller.py +317 -0
  13. data_designer/cli/forms/__init__.py +20 -0
  14. data_designer/cli/forms/builder.py +51 -0
  15. data_designer/cli/forms/field.py +180 -0
  16. data_designer/cli/forms/form.py +59 -0
  17. data_designer/cli/forms/model_builder.py +125 -0
  18. data_designer/cli/forms/provider_builder.py +76 -0
  19. data_designer/cli/main.py +44 -0
  20. data_designer/cli/repositories/__init__.py +8 -0
  21. data_designer/cli/repositories/base.py +39 -0
  22. data_designer/cli/repositories/model_repository.py +42 -0
  23. data_designer/cli/repositories/provider_repository.py +43 -0
  24. data_designer/cli/services/__init__.py +7 -0
  25. data_designer/cli/services/model_service.py +116 -0
  26. data_designer/cli/services/provider_service.py +111 -0
  27. data_designer/cli/ui.py +448 -0
  28. data_designer/cli/utils.py +47 -0
  29. data_designer/config/__init__.py +2 -0
  30. data_designer/config/analysis/column_profilers.py +89 -0
  31. data_designer/config/analysis/column_statistics.py +274 -0
  32. data_designer/config/analysis/dataset_profiler.py +60 -0
  33. data_designer/config/analysis/utils/errors.py +8 -0
  34. data_designer/config/analysis/utils/reporting.py +188 -0
  35. data_designer/config/base.py +68 -0
  36. data_designer/config/column_configs.py +354 -0
  37. data_designer/config/column_types.py +168 -0
  38. data_designer/config/config_builder.py +660 -0
  39. data_designer/config/data_designer_config.py +40 -0
  40. data_designer/config/dataset_builders.py +11 -0
  41. data_designer/config/datastore.py +151 -0
  42. data_designer/config/default_model_settings.py +123 -0
  43. data_designer/config/errors.py +19 -0
  44. data_designer/config/interface.py +54 -0
  45. data_designer/config/models.py +231 -0
  46. data_designer/config/preview_results.py +32 -0
  47. data_designer/config/processors.py +41 -0
  48. data_designer/config/sampler_constraints.py +51 -0
  49. data_designer/config/sampler_params.py +604 -0
  50. data_designer/config/seed.py +145 -0
  51. data_designer/config/utils/code_lang.py +83 -0
  52. data_designer/config/utils/constants.py +313 -0
  53. data_designer/config/utils/errors.py +19 -0
  54. data_designer/config/utils/info.py +88 -0
  55. data_designer/config/utils/io_helpers.py +273 -0
  56. data_designer/config/utils/misc.py +81 -0
  57. data_designer/config/utils/numerical_helpers.py +28 -0
  58. data_designer/config/utils/type_helpers.py +100 -0
  59. data_designer/config/utils/validation.py +336 -0
  60. data_designer/config/utils/visualization.py +427 -0
  61. data_designer/config/validator_params.py +96 -0
  62. data_designer/engine/__init__.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +55 -0
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
  65. data_designer/engine/analysis/column_profilers/registry.py +20 -0
  66. data_designer/engine/analysis/column_statistics.py +142 -0
  67. data_designer/engine/analysis/dataset_profiler.py +125 -0
  68. data_designer/engine/analysis/errors.py +7 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
  70. data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
  71. data_designer/engine/column_generators/__init__.py +2 -0
  72. data_designer/engine/column_generators/generators/__init__.py +2 -0
  73. data_designer/engine/column_generators/generators/base.py +61 -0
  74. data_designer/engine/column_generators/generators/expression.py +63 -0
  75. data_designer/engine/column_generators/generators/llm_generators.py +172 -0
  76. data_designer/engine/column_generators/generators/samplers.py +75 -0
  77. data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
  78. data_designer/engine/column_generators/generators/validation.py +147 -0
  79. data_designer/engine/column_generators/registry.py +56 -0
  80. data_designer/engine/column_generators/utils/errors.py +13 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
  83. data_designer/engine/configurable_task.py +82 -0
  84. data_designer/engine/dataset_builders/artifact_storage.py +181 -0
  85. data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
  86. data_designer/engine/dataset_builders/errors.py +13 -0
  87. data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
  88. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +56 -0
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
  93. data_designer/engine/dataset_builders/utils/errors.py +13 -0
  94. data_designer/engine/errors.py +49 -0
  95. data_designer/engine/model_provider.py +75 -0
  96. data_designer/engine/models/__init__.py +2 -0
  97. data_designer/engine/models/errors.py +308 -0
  98. data_designer/engine/models/facade.py +225 -0
  99. data_designer/engine/models/litellm_overrides.py +162 -0
  100. data_designer/engine/models/parsers/__init__.py +2 -0
  101. data_designer/engine/models/parsers/errors.py +34 -0
  102. data_designer/engine/models/parsers/parser.py +236 -0
  103. data_designer/engine/models/parsers/postprocessors.py +93 -0
  104. data_designer/engine/models/parsers/tag_parsers.py +60 -0
  105. data_designer/engine/models/parsers/types.py +82 -0
  106. data_designer/engine/models/recipes/base.py +79 -0
  107. data_designer/engine/models/recipes/response_recipes.py +291 -0
  108. data_designer/engine/models/registry.py +118 -0
  109. data_designer/engine/models/usage.py +75 -0
  110. data_designer/engine/models/utils.py +38 -0
  111. data_designer/engine/processing/ginja/__init__.py +2 -0
  112. data_designer/engine/processing/ginja/ast.py +64 -0
  113. data_designer/engine/processing/ginja/environment.py +461 -0
  114. data_designer/engine/processing/ginja/exceptions.py +54 -0
  115. data_designer/engine/processing/ginja/record.py +30 -0
  116. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  117. data_designer/engine/processing/gsonschema/exceptions.py +8 -0
  118. data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
  119. data_designer/engine/processing/gsonschema/types.py +8 -0
  120. data_designer/engine/processing/gsonschema/validators.py +143 -0
  121. data_designer/engine/processing/processors/base.py +15 -0
  122. data_designer/engine/processing/processors/drop_columns.py +46 -0
  123. data_designer/engine/processing/processors/registry.py +20 -0
  124. data_designer/engine/processing/utils.py +120 -0
  125. data_designer/engine/registry/base.py +97 -0
  126. data_designer/engine/registry/data_designer_registry.py +37 -0
  127. data_designer/engine/registry/errors.py +10 -0
  128. data_designer/engine/resources/managed_dataset_generator.py +35 -0
  129. data_designer/engine/resources/managed_dataset_repository.py +194 -0
  130. data_designer/engine/resources/managed_storage.py +63 -0
  131. data_designer/engine/resources/resource_provider.py +46 -0
  132. data_designer/engine/resources/seed_dataset_data_store.py +66 -0
  133. data_designer/engine/sampling_gen/column.py +89 -0
  134. data_designer/engine/sampling_gen/constraints.py +95 -0
  135. data_designer/engine/sampling_gen/data_sources/base.py +214 -0
  136. data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
  137. data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
  138. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  139. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  140. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
  141. data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
  142. data_designer/engine/sampling_gen/entities/errors.py +8 -0
  143. data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
  144. data_designer/engine/sampling_gen/entities/person.py +142 -0
  145. data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
  146. data_designer/engine/sampling_gen/errors.py +24 -0
  147. data_designer/engine/sampling_gen/generator.py +121 -0
  148. data_designer/engine/sampling_gen/jinja_utils.py +60 -0
  149. data_designer/engine/sampling_gen/people_gen.py +203 -0
  150. data_designer/engine/sampling_gen/person_constants.py +54 -0
  151. data_designer/engine/sampling_gen/schema.py +143 -0
  152. data_designer/engine/sampling_gen/schema_builder.py +59 -0
  153. data_designer/engine/sampling_gen/utils.py +40 -0
  154. data_designer/engine/secret_resolver.py +80 -0
  155. data_designer/engine/validators/__init__.py +17 -0
  156. data_designer/engine/validators/base.py +36 -0
  157. data_designer/engine/validators/local_callable.py +34 -0
  158. data_designer/engine/validators/python.py +245 -0
  159. data_designer/engine/validators/remote.py +83 -0
  160. data_designer/engine/validators/sql.py +60 -0
  161. data_designer/errors.py +5 -0
  162. data_designer/essentials/__init__.py +137 -0
  163. data_designer/interface/__init__.py +2 -0
  164. data_designer/interface/data_designer.py +351 -0
  165. data_designer/interface/errors.py +16 -0
  166. data_designer/interface/results.py +55 -0
  167. data_designer/logging.py +161 -0
  168. data_designer/plugin_manager.py +83 -0
  169. data_designer/plugins/__init__.py +6 -0
  170. data_designer/plugins/errors.py +10 -0
  171. data_designer/plugins/plugin.py +69 -0
  172. data_designer/plugins/registry.py +86 -0
  173. data_designer-0.1.0.dist-info/METADATA +173 -0
  174. data_designer-0.1.0.dist-info/RECORD +177 -0
  175. data_designer-0.1.0.dist-info/WHEEL +4 -0
  176. data_designer-0.1.0.dist-info/entry_points.txt +2 -0
  177. data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,128 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from collections import defaultdict
5
+ import logging
6
+ from typing import Any, Optional, Union
7
+
8
+ import pandas as pd
9
+
10
+ from data_designer.config.analysis.column_profilers import JudgeScoreDistributions, JudgeScoreSample
11
+ from data_designer.config.analysis.column_statistics import (
12
+ CategoricalDistribution,
13
+ ColumnDistributionType,
14
+ MissingValue,
15
+ NumericalDistribution,
16
+ )
17
+ from data_designer.config.column_configs import LLMJudgeColumnConfig
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def extract_judge_score_distributions(
    column_config: LLMJudgeColumnConfig, df: pd.DataFrame
) -> Union[JudgeScoreDistributions, MissingValue]:
    """Extract per-rubric score/reasoning values and summary distributions for a judge column.

    For each score configured on ``column_config``, collects the raw score and
    reasoning from every row of ``df[column_config.name]``, then builds a
    histogram plus a numerical or categorical distribution per score name.

    Args:
        column_config: The LLM-judge column config whose ``scores`` define the rubrics.
        df: Dataset containing the judge output column (one dict of results per row).

    Returns:
        A ``JudgeScoreDistributions`` aggregate, or ``MissingValue.OUTPUT_FORMAT_ERROR``
        if any row's judge output cannot be parsed. Per-score calculation failures
        are recorded as ``MissingValue.CALCULATION_FAILED`` without aborting.
    """
    scores = defaultdict(list)
    reasoning = defaultdict(list)

    # Aggregate results as dicts of form {score_name: <result>}.
    histograms = {}
    distributions = {}
    distribution_types = {}

    # NOTE: the per-row value deliberately uses a distinct name (`value`) so it
    # does not shadow the config object iterated here (previously both were `score`).
    for score_config in column_config.scores:
        is_numerical = True
        name = score_config.name.lower()
        for results in df[column_config.name]:
            try:
                value = results[name].get("score", None)

                if _can_be_converted_to_int(value):
                    value = int(value)
                else:
                    # A single non-integer value makes the whole score categorical.
                    value = str(value)
                    is_numerical = False

                scores[name].append(value)
                reasoning[name].append(results[name].get("reasoning", "No reasoning provided"))
            except Exception as e:
                logger.warning(f"⚠️ Failed to extract judge score for '{name}': {e}")
                return MissingValue.OUTPUT_FORMAT_ERROR

        try:
            series = pd.Series(scores[name], name=name)
            cat_dist = CategoricalDistribution.from_series(series)

            # For judge scores, build a categorical histogram, since numerical scores are integers.
            histograms[name] = cat_dist.histogram

            if is_numerical:
                distribution_types[name] = ColumnDistributionType.NUMERICAL
                distributions[name] = NumericalDistribution.from_series(series)
            else:
                distribution_types[name] = ColumnDistributionType.CATEGORICAL
                distributions[name] = cat_dist

        except Exception as e:
            logger.warning(f"⚠️ Failed to calculate judge score distribution for '{name}': {e}")
            distribution_types[name] = ColumnDistributionType.UNKNOWN
            distributions[name] = MissingValue.CALCULATION_FAILED
            histograms[name] = MissingValue.CALCULATION_FAILED

    return JudgeScoreDistributions(
        scores=dict(scores),
        reasoning=dict(reasoning),
        distribution_types=distribution_types,
        distributions=distributions,
        histograms=histograms,
    )
79
+
80
+
81
def sample_scores_and_reasoning(
    scores: list[Union[int, str]],
    reasoning: list[str],
    num_samples: int,
    random_seed: Optional[int] = None,
) -> list[JudgeScoreSample]:
    """Sample up to ``num_samples`` (score, reasoning) pairs, preserving score proportions.

    Args:
        scores: Score values (ints or categorical strings), one per record.
        reasoning: Reasoning strings aligned index-for-index with ``scores``.
        num_samples: Maximum number of samples to return.
        random_seed: Optional seed for reproducible sampling.

    Returns:
        A list of ``JudgeScoreSample`` objects.

    Raises:
        ValueError: If the inputs are misaligned, empty, or ``num_samples <= 0``.
    """
    if len(scores) != len(reasoning):
        raise ValueError("scores and reasoning must have the same length")

    if len(scores) == 0:
        raise ValueError("scores and reasoning must not be empty")

    if num_samples <= 0:
        raise ValueError("num_samples must be greater than 0")

    # Fewer records than requested samples: return everything; no sampling
    # (and no DataFrame construction) needed.
    if len(scores) <= num_samples:
        return [JudgeScoreSample(score=s, reasoning=r) for s, r in zip(scores, reasoning)]

    df_samples = pd.DataFrame({"score": scores, "reasoning": reasoning})

    # Sample maintaining original proportions from each category (int or str)
    # Calculate the frequency of each score category
    score_category_counts = df_samples["score"].value_counts()

    # If more categories than samples, pick one sample from each of the most frequent categories
    if len(score_category_counts) >= num_samples:
        top_categories = score_category_counts.head(num_samples).index
        samples = pd.concat(
            [df_samples[df_samples["score"] == cat].sample(n=1, random_state=random_seed) for cat in top_categories],
            ignore_index=True,
        )
    else:
        # Sample proportionally to maintain original category ratios
        # Create weights based on the original frequency of each score
        weights = df_samples["score"].map(score_category_counts)
        samples = df_samples.sample(n=num_samples, weights=weights, random_state=random_seed)

    return [
        JudgeScoreSample(score=row["score"], reasoning=row["reasoning"]) for row in samples.to_dict(orient="records")
    ]
121
+
122
+
123
+ def _can_be_converted_to_int(value: Any) -> bool:
124
+ try:
125
+ int(value)
126
+ return True
127
+ except (ValueError, TypeError):
128
+ return False
@@ -0,0 +1,2 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1,2 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1,61 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from abc import ABC, abstractmethod
5
+ from typing import overload
6
+
7
+ import pandas as pd
8
+
9
+ from data_designer.config.utils.type_helpers import StrEnum
10
+ from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT
11
+
12
+
13
class GenerationStrategy(StrEnum):
    """How a generator produces values for its column."""

    # One `generate` call per record (generators receive a dict).
    CELL_BY_CELL = "cell_by_cell"
    # One `generate` call over the whole dataset (generators receive a DataFrame).
    FULL_COLUMN = "full_column"
16
+
17
+
18
class GeneratorMetadata(ConfigurableTaskMetadata):
    """Task metadata extended with the generator's generation strategy."""

    # Whether the generator runs cell-by-cell or over the full column.
    generation_strategy: GenerationStrategy
20
+
21
+
22
class ColumnGenerator(ConfigurableTask[TaskConfigT], ABC):
    """Abstract base class for tasks that generate a single dataset column.

    Subclasses declare their strategy (cell-by-cell vs. full-column) through
    `metadata()` and implement `generate`, which — per the overloads below —
    accepts either a single record (dict) or a full dataset (DataFrame).
    """

    @property
    def can_generate_from_scratch(self) -> bool:
        # Overridden to True only by FromScratchColumnGenerator subclasses,
        # which can produce data without an existing dataset to extend.
        return False

    @property
    def generation_strategy(self) -> GenerationStrategy:
        """Convenience accessor for the strategy declared in `metadata()`."""
        return self.metadata().generation_strategy

    @staticmethod
    @abstractmethod
    def metadata() -> GeneratorMetadata:
        """Return the generator's static registry metadata."""
        ...

    @overload
    @abstractmethod
    def generate(self, data: dict) -> dict: ...

    @overload
    @abstractmethod
    def generate(self, data: pd.DataFrame) -> pd.DataFrame: ...

    @abstractmethod
    def generate(self, data: DataT) -> DataT:
        """Generate this column's values, returning the same data shape as given."""
        ...

    def log_pre_generation(self) -> None:
        """A shared method to log info before the generator's `generate` method is called.

        The idea is for dataset builders to call this method for all generators before calling their
        `generate` method. This is to avoid logging the same information multiple times when running
        generators in parallel.
        """
53
+
54
+
55
class FromScratchColumnGenerator(ColumnGenerator[TaskConfigT], ABC):
    """Column generator that can create records without any pre-existing data."""

    @property
    def can_generate_from_scratch(self) -> bool:
        return True

    @abstractmethod
    def generate_from_scratch(self, num_records: int) -> pd.DataFrame:
        """Generate `num_records` new records as a DataFrame."""
        ...
@@ -0,0 +1,63 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import logging
5
+
6
+ import pandas as pd
7
+
8
+ from data_designer.config.column_configs import ExpressionColumnConfig
9
+ from data_designer.engine.column_generators.generators.base import (
10
+ ColumnGenerator,
11
+ GenerationStrategy,
12
+ GeneratorMetadata,
13
+ )
14
+ from data_designer.engine.column_generators.utils.errors import ExpressionTemplateRenderError
15
+ from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering
16
+ from data_designer.engine.processing.utils import deserialize_json_values
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class ExpressionColumnGenerator(WithJinja2UserTemplateRendering, ColumnGenerator[ExpressionColumnConfig]):
    """Generate a column by rendering a Jinja2 expression against each record."""

    @staticmethod
    def metadata() -> GeneratorMetadata:
        return GeneratorMetadata(
            name="expression_generator",
            description="Generate a column from a jinja2 expression.",
            generation_strategy=GenerationStrategy.FULL_COLUMN,
            required_resources=None,
        )

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        """Render the configured expression for every record and append the result column.

        Raises:
            ExpressionTemplateRenderError: If any column referenced by the
                expression is missing from `data`.
        """
        logger.info(f"🧩 Generating column `{self.config.name}` from expression")

        missing_columns = list(set(self.config.required_columns) - set(data.columns))
        if missing_columns:
            error_msg = (
                f"There was an error preparing the Jinja2 expression template. "
                f"The following columns {missing_columns} are missing!"
            )
            raise ExpressionTemplateRenderError(error_msg)

        self.prepare_jinja2_template_renderer(self.config.expr, data.columns.to_list())
        records = []
        for record in data.to_dict(orient="records"):
            # Render against the JSON-deserialized record so nested values are usable
            # in the template, then cast the rendered string to the configured dtype.
            record[self.config.name] = self._cast_type(self.render_template(deserialize_json_values(record)))
            records.append(record)

        return pd.DataFrame(records)

    def _cast_type(self, value: str) -> str | float | int | bool:
        """Cast the rendered template output to the configured dtype.

        Raises:
            ValueError: If `self.config.dtype` is not one of str/float/int/bool.
        """
        if self.config.dtype == "str":
            return value
        elif self.config.dtype == "float":
            return float(value)
        elif self.config.dtype == "int":
            # Go through float first so values like "3.0" cast cleanly.
            return int(float(value))
        elif self.config.dtype == "bool":
            try:
                # Numeric strings: nonzero -> True, zero -> False.
                return bool(int(float(value)))
            except ValueError:
                # Non-numeric strings: only a case-insensitive "true" is truthy.
                return f"{value}".lower() == "true"
        else:
            raise ValueError(f"Invalid dtype: {self.config.dtype}")
@@ -0,0 +1,172 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import functools
5
+ import logging
6
+
7
+ from data_designer.config.column_configs import (
8
+ LLMCodeColumnConfig,
9
+ LLMJudgeColumnConfig,
10
+ LLMStructuredColumnConfig,
11
+ LLMTextColumnConfig,
12
+ )
13
+ from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP
14
+ from data_designer.config.models import InferenceParameters, ModelConfig
15
+ from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX
16
+ from data_designer.engine.column_generators.generators.base import (
17
+ ColumnGenerator,
18
+ GenerationStrategy,
19
+ GeneratorMetadata,
20
+ )
21
+ from data_designer.engine.column_generators.utils.prompt_renderer import (
22
+ PromptType,
23
+ RecordBasedPromptRenderer,
24
+ create_response_recipe,
25
+ )
26
+ from data_designer.engine.models.facade import ModelFacade
27
+ from data_designer.engine.models.recipes.base import ResponseRecipe
28
+ from data_designer.engine.processing.utils import deserialize_json_values
29
+ from data_designer.engine.resources.resource_provider import ResourceType
30
+
31
+ DEFAULT_MAX_CONVERSATION_RESTARTS = 5
32
+ DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0
33
+
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
class WithLLMGeneration:
    """Mixin providing LLM-backed, cell-by-cell `generate` for column generators.

    Assumes the composed class supplies `self.config` (an LLM column config with
    `name`, `model_alias`, `prompt`, `system_prompt`, `column_type`, and
    `multi_modal_context`) and `self.resource_provider` (exposing the model
    registry) — presumably from the `ColumnGenerator`/`ConfigurableTask` side of
    the composition; TODO(review): confirm that attribute contract.
    """

    @functools.cached_property
    def model(self) -> ModelFacade:
        # Resolved once per instance from the registry by the configured alias.
        return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias)

    @functools.cached_property
    def model_config(self) -> ModelConfig:
        # Full model configuration (provider, inference params, ...) for the alias.
        return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias)

    @functools.cached_property
    def inference_parameters(self) -> InferenceParameters:
        return self.model_config.inference_parameters

    @functools.cached_property
    def prompt_renderer(self) -> RecordBasedPromptRenderer:
        # The error context is attached to render failures so they can be traced
        # back to this column/model combination.
        return RecordBasedPromptRenderer(
            response_recipe=self.response_recipe,
            error_message_context={
                "column_name": self.config.name,
                "column_type": self.config.column_type,
                "model_alias": self.config.model_alias,
            },
        )

    @functools.cached_property
    def response_recipe(self) -> ResponseRecipe:
        # Recipe that formats/parses model responses for this column + model pair.
        return create_response_recipe(self.config, self.model_config)

    @property
    def max_conversation_correction_steps(self) -> int:
        # Subclasses may override to allow in-conversation corrections.
        return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS

    @property
    def max_conversation_restarts(self) -> int:
        # Subclasses may override (e.g. the judge generator doubles this budget).
        return DEFAULT_MAX_CONVERSATION_RESTARTS

    def generate(self, data: dict) -> dict:
        """Generate this column's value for a single record, mutating and returning `data`.

        Renders the user and system prompts against the JSON-deserialized record,
        calls the model (optionally with multi-modal context), parses the response
        via the response recipe, and stores the result under the column name. A
        reasoning trace, when returned, is stored in a companion column suffixed
        with `REASONING_TRACE_COLUMN_POSTFIX`.
        """
        deserialized_record = deserialize_json_values(data)

        multi_modal_context = None
        if self.config.multi_modal_context is not None and len(self.config.multi_modal_context) > 0:
            multi_modal_context = [
                context.get_context(deserialized_record) for context in self.config.multi_modal_context
            ]

        response, reasoning_trace = self.model.generate(
            prompt=self.prompt_renderer.render(
                record=deserialized_record,
                prompt_template=self.config.prompt,
                prompt_type=PromptType.USER_PROMPT,
            ),
            system_prompt=self.prompt_renderer.render(
                record=deserialized_record,
                prompt_template=self.config.system_prompt,
                prompt_type=PromptType.SYSTEM_PROMPT,
            ),
            parser=self.response_recipe.parse,
            multi_modal_context=multi_modal_context,
            max_correction_steps=self.max_conversation_correction_steps,
            max_conversation_restarts=self.max_conversation_restarts,
            purpose=f"running generation for column '{self.config.name}'",
            **self.inference_parameters.generate_kwargs,
        )

        # Round-trip through the recipe's serializer so the stored value matches
        # the recipe's canonical output format.
        data[self.config.name] = deserialize_json_values(self.response_recipe.serialize_output(response))

        if reasoning_trace:
            data[self.config.name + REASONING_TRACE_COLUMN_POSTFIX] = reasoning_trace

        return data

    def log_pre_generation(self) -> None:
        """Log the column/model setup once, before per-record generation starts."""
        emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type]
        logger.info(f"{emoji} Preparing {self.config.column_type} column generation")
        logger.info(f"  |-- column name: {self.config.name!r}")
        logger.info(f"  |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
        # Only resolve/report the default provider when none is pinned in the config.
        if self.model_config.provider is None:
            logger.info(f"  |-- default model provider: {self._get_provider_name()!r}")

    def _get_provider_name(self) -> str:
        # Resolve which provider the registry would use for this model alias.
        model_alias = self.model_config.alias
        provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias)
        return provider.name
121
+
122
+
123
class LLMTextCellGenerator(WithLLMGeneration, ColumnGenerator[LLMTextColumnConfig]):
    """Cell-by-cell text column generator backed by the shared LLM generation mixin."""

    @staticmethod
    def metadata() -> GeneratorMetadata:
        """Describe this generator for the registry."""
        generator_metadata = GeneratorMetadata(
            name="llm_text_generator",
            description="Generate a new dataset cell from a prompt template",
            generation_strategy=GenerationStrategy.CELL_BY_CELL,
            required_resources=[ResourceType.MODEL_REGISTRY],
        )
        return generator_metadata
132
+
133
+
134
class LLMCodeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMCodeColumnConfig]):
    """Cell-by-cell code column generator backed by the shared LLM generation mixin."""

    @staticmethod
    def metadata() -> GeneratorMetadata:
        """Describe this generator for the registry."""
        generator_metadata = GeneratorMetadata(
            name="llm_code_generator",
            description="Generate a new dataset cell from a prompt template",
            generation_strategy=GenerationStrategy.CELL_BY_CELL,
            required_resources=[ResourceType.MODEL_REGISTRY],
        )
        return generator_metadata
143
+
144
+
145
class LLMStructuredCellGenerator(WithLLMGeneration, ColumnGenerator[LLMStructuredColumnConfig]):
    """Cell-by-cell structured-output column generator backed by the shared LLM mixin."""

    @staticmethod
    def metadata() -> GeneratorMetadata:
        """Describe this generator for the registry."""
        generator_metadata = GeneratorMetadata(
            name="llm_structured_generator",
            description="Generate a new dataset cell from a prompt template",
            generation_strategy=GenerationStrategy.CELL_BY_CELL,
            required_resources=[ResourceType.MODEL_REGISTRY],
        )
        return generator_metadata
154
+
155
+
156
class LLMJudgeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMJudgeColumnConfig]):
    """Cell-by-cell generator that judges records against a set of rubrics.

    `max_conversation_correction_steps` is inherited from `WithLLMGeneration`
    unchanged (a previous override here duplicated the default); only the
    restart budget differs — it is doubled for judge columns.
    """

    @staticmethod
    def metadata() -> GeneratorMetadata:
        return GeneratorMetadata(
            name="llm_judge_generator",
            description="Judge a new dataset cell based on a set of rubrics",
            generation_strategy=GenerationStrategy.CELL_BY_CELL,
            required_resources=[ResourceType.MODEL_REGISTRY],
        )

    @property
    def max_conversation_restarts(self) -> int:
        # Judge columns get twice the default restart budget.
        return 2 * DEFAULT_MAX_CONVERSATION_RESTARTS
@@ -0,0 +1,75 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from functools import partial
5
+ import logging
6
+ import random
7
+ from typing import Callable
8
+
9
+ import pandas as pd
10
+
11
+ from data_designer.config.utils.constants import LOCALES_WITH_MANAGED_DATASETS
12
+ from data_designer.engine.column_generators.generators.base import (
13
+ FromScratchColumnGenerator,
14
+ GenerationStrategy,
15
+ GeneratorMetadata,
16
+ )
17
+ from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
18
+ from data_designer.engine.processing.utils import concat_datasets
19
+ from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
20
+ from data_designer.engine.resources.resource_provider import ResourceType
21
+ from data_designer.engine.sampling_gen.data_sources.sources import SamplerType
22
+ from data_designer.engine.sampling_gen.entities.person import load_person_data_sampler
23
+ from data_designer.engine.sampling_gen.generator import DatasetGenerator as SamplingDatasetGenerator
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class SamplerColumnGenerator(FromScratchColumnGenerator[SamplerMultiColumnConfig]):
    """Generate one or more columns with non-LLM, sampling-based data sources."""

    @staticmethod
    def metadata() -> GeneratorMetadata:
        return GeneratorMetadata(
            name="sampler_column_generator",
            description="Generate columns using sampling-based method.",
            generation_strategy=GenerationStrategy.FULL_COLUMN,
            required_resources=[ResourceType.BLOB_STORAGE],
        )

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        """Sample one row per input record and append the sampled columns to `data`."""
        df_samplers = self.generate_from_scratch(len(data))
        return concat_datasets([data, df_samplers])

    def generate_from_scratch(self, num_records: int) -> pd.DataFrame:
        """Generate `num_records` records purely from the configured samplers."""
        sampling_generator = self._prepare_for_generation(num_records)
        return sampling_generator.generate(num_records)

    @property
    def _person_columns(self) -> list:
        # Single source of truth for the person-sampler columns; previously this
        # filter was duplicated in two methods.
        return [c for c in self.config.columns if c.sampler_type == SamplerType.PERSON]

    @property
    def _needs_person_generator(self) -> bool:
        # The managed person dataset is only needed for locales that have one.
        return any(c.params.locale in LOCALES_WITH_MANAGED_DATASETS for c in self._person_columns)

    @property
    def _person_generator_loader(self) -> Callable[[bool], ManagedDatasetGenerator]:
        return partial(load_person_data_sampler, blob_storage=self.resource_provider.blob_storage)

    def _create_sampling_dataset_generator(self) -> SamplingDatasetGenerator:
        return SamplingDatasetGenerator(
            sampler_columns=self.config,
            person_generator_loader=(self._person_generator_loader if self._needs_person_generator else None),
        )

    def _log_person_generation_if_needed(self) -> None:
        """Emit a one-time log line when person generation will be initialized."""
        if self._needs_person_generator:
            emoji = random.choice(["🧑‍🎨", "🙋‍♂️", "🙋‍♀️", "🧑‍🚀", "👩‍🎤", "👨‍🍳", "👩‍🔬", "👨‍💻", "👩‍💼"])
            log_msg = f"🎲 {emoji} Initializing person generation"
            if any(c.params.with_synthetic_personas for c in self._person_columns):
                log_msg += " ⚡️ with synthetic personas ⚡️"
            logger.info(log_msg)

    def _prepare_for_generation(self, num_records: int) -> SamplingDatasetGenerator:
        logger.info(
            f"🎲 Preparing samplers to generate {num_records} records across {len(self.config.columns)} columns"
        )
        self._log_person_generation_if_needed()
        return self._create_sampling_dataset_generator()
@@ -0,0 +1,149 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import functools
5
+ import logging
6
+
7
+ import duckdb
8
+ import pandas as pd
9
+
10
+ from data_designer.config.seed import IndexRange, PartitionBlock, SamplingStrategy
11
+ from data_designer.engine.column_generators.generators.base import (
12
+ FromScratchColumnGenerator,
13
+ GenerationStrategy,
14
+ GeneratorMetadata,
15
+ )
16
+ from data_designer.engine.column_generators.utils.errors import SeedDatasetError
17
+ from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig
18
+ from data_designer.engine.processing.utils import concat_datasets
19
+ from data_designer.engine.resources.resource_provider import ResourceType
20
+
21
+ MAX_ZERO_RECORD_RESPONSE_FACTOR = 2
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColumnConfig]):
    """Generate columns by streaming records out of a seed dataset in the datastore.

    Records are read through a DuckDB connection in Arrow record batches. An optional
    selection strategy (``IndexRange`` or ``PartitionBlock``) restricts which rows are
    visible; an optional shuffle strategy randomizes row order. When more records are
    requested than remain in the stream, the reader wraps around to the beginning of
    the (selected) dataset, so sampling effectively cycles.
    """

    @staticmethod
    def metadata() -> GeneratorMetadata:
        """Describe this generator: full-column strategy, requires a datastore resource."""
        return GeneratorMetadata(
            name="seed_dataset_column_generator",
            description="Sample columns from a seed dataset.",
            generation_strategy=GenerationStrategy.FULL_COLUMN,
            required_resources=[ResourceType.DATASTORE],
        )

    @property
    def num_records_sampled(self) -> int:
        # Running total across all calls to generate/generate_from_scratch.
        return self._num_records_sampled

    @functools.cached_property
    def duckdb_conn(self) -> duckdb.DuckDBPyConnection:
        # One connection per generator instance; cached_property ensures single creation.
        return self.resource_provider.datastore.create_duckdb_connection()

    def generate(self, dataset: pd.DataFrame) -> pd.DataFrame:
        """Prepend seed-sampled columns (one row per row of ``dataset``) to ``dataset``."""
        return concat_datasets([self.generate_from_scratch(len(dataset)), dataset])

    def generate_from_scratch(self, num_records: int) -> pd.DataFrame:
        """Sample ``num_records`` rows from the seed dataset.

        Raises:
            ValueError: if ``num_records`` is not positive.
        """
        if num_records <= 0:
            raise ValueError("🛑 `num_records` must be positive.")

        # Lazily create the batch reader on first use; batch size tracks the request size.
        if self._batch_reader is None:
            self._reset_batch_reader(num_records)

        return self._sample_records(num_records)

    def _initialize(self) -> None:
        # Reset sampling state and resolve dataset location/size up front.
        self._num_records_sampled = 0
        self._batch_reader = None
        self._df_remaining = None
        self._dataset_uri = self.resource_provider.datastore.get_dataset_uri(self.config.dataset)
        # NOTE(review): dataset URI is interpolated directly into SQL — assumed to come
        # from a trusted datastore, not user input; confirm upstream sanitization.
        self._seed_dataset_size = self.duckdb_conn.execute(f"SELECT COUNT(*) FROM '{self._dataset_uri}'").fetchone()[0]
        self._index_range = self._resolve_index_range()

    def _validate_selection_strategy(self) -> None:
        """Raise SeedDatasetError when the configured selection exceeds the dataset bounds."""
        err_msg = None
        if self.config.selection_strategy is not None:
            if (
                isinstance(self.config.selection_strategy, IndexRange)
                and self.config.selection_strategy.end >= self._seed_dataset_size
            ):
                err_msg = f"Selection strategy 'end' index {self.config.selection_strategy.end} is out of bounds for dataset size {self._seed_dataset_size}"
            elif (
                isinstance(self.config.selection_strategy, PartitionBlock)
                and self.config.selection_strategy.num_partitions > self._seed_dataset_size
            ):
                err_msg = f"Selection strategy 'num_partitions' {self.config.selection_strategy.num_partitions} is out of bounds for dataset size {self._seed_dataset_size}"
        if err_msg is not None:
            raise SeedDatasetError(err_msg)

    def _resolve_index_range(self) -> IndexRange | None:
        """Normalize the selection strategy to a concrete IndexRange (or None for full dataset)."""
        self._validate_selection_strategy()
        index_range = None
        if self.config.selection_strategy is not None:
            if isinstance(self.config.selection_strategy, IndexRange):
                index_range = self.config.selection_strategy
            elif isinstance(self.config.selection_strategy, PartitionBlock):
                # Partition blocks are converted to an equivalent row range.
                index_range = self.config.selection_strategy.to_index_range(self._seed_dataset_size)
        return index_range

    def _reset_batch_reader(self, num_records: int) -> None:
        """(Re)build the Arrow record-batch reader over the selected, optionally shuffled rows."""
        shuffle = self.config.sampling_strategy == SamplingStrategy.SHUFFLE
        shuffle_query = " ORDER BY RANDOM()" if shuffle else ""

        if self._index_range is not None:
            # Use LIMIT and OFFSET for efficient index range filtering
            # IndexRange uses 0-based indexing [start, end] inclusive
            # OFFSET skips the first 'start' rows (0-based)
            # LIMIT takes 'end - start + 1' rows to include both start and end (inclusive)
            offset_value = self._index_range.start
            limit_value = self._index_range.end - self._index_range.start + 1
            read_query = f"""
                SELECT * FROM '{self._dataset_uri}'
                LIMIT {limit_value} OFFSET {offset_value}
            """

            # Shuffle is applied AFTER the range selection, so the same rows are
            # always selected regardless of shuffle.
            read_query = f"SELECT * FROM ({read_query}){shuffle_query}"
        else:
            read_query = f"SELECT * FROM '{self._dataset_uri}'{shuffle_query}"
        self._batch_reader = self.duckdb_conn.query(read_query).record_batch(batch_size=num_records)

    def _sample_records(self, num_records: int) -> pd.DataFrame:
        """Pull ``num_records`` rows from the batch reader, wrapping around at end of stream.

        Rows read beyond the request are stashed in ``self._df_remaining`` and served
        first on the next call. A stall guard aborts if the reader keeps returning
        empty batches.
        """
        logger.info(f"🌱 Sampling {num_records} records from seed dataset")
        logger.info(f"  |-- seed dataset size: {self._seed_dataset_size} records")
        logger.info(f"  |-- sampling strategy: {self.config.sampling_strategy}")
        if self._index_range is not None:
            if isinstance(self.config.selection_strategy, IndexRange):
                logger.info(f"  |-- selection: rows [{self._index_range.start} to {self._index_range.end}] inclusive")
            else:
                logger.info(
                    f"  |-- selection: partition {self.config.selection_strategy.index + 1} of {self.config.selection_strategy.num_partitions}"
                )
            logger.info(f"  |-- seed dataset size after selection: {self._index_range.size} records")
        df_batch = pd.DataFrame()
        # Start from any leftover rows carried over from the previous call.
        df_sample = pd.DataFrame() if self._df_remaining is None else self._df_remaining
        num_zero_record_responses = 0

        while len(df_sample) < num_records:
            try:
                df_batch = self._batch_reader.read_next_batch().to_pandas()
                df_sample = pd.concat([df_sample, df_batch], ignore_index=True)
            except StopIteration:
                # Stream exhausted: rebuild the reader so sampling wraps around.
                self._reset_batch_reader(num_records)

            # Stall guard: too many consecutive loop passes without data means the
            # datastore is not returning rows at all — fail loudly instead of spinning.
            if len(df_batch) == 0:
                num_zero_record_responses += 1
                if num_zero_record_responses > MAX_ZERO_RECORD_RESPONSE_FACTOR * num_records:
                    raise RuntimeError(
                        "🛑 Something went wrong while reading from the datastore. "
                        "Please check your connection and try again. "
                        "If the issue persists, please contact support."
                    )

        # Keep any overshoot for the next call and trim the sample to the request size.
        self._df_remaining = None
        if len(df_sample) > num_records:
            self._df_remaining = df_sample.iloc[num_records:].reset_index(drop=True)
            df_sample = df_sample.iloc[:num_records]
        self._num_records_sampled += len(df_sample)

        return df_sample