data-designer 0.3.8rc1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the content of publicly released package versions as published to their respective public registries. It is provided for informational purposes only.
Files changed (166)
  1. data_designer/cli/commands/__init__.py +1 -1
  2. data_designer/interface/__init__.py +21 -1
  3. data_designer/{_version.py → interface/_version.py} +2 -2
  4. data_designer/interface/data_designer.py +8 -11
  5. {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/METADATA +10 -42
  6. data_designer-0.4.0.dist-info/RECORD +39 -0
  7. data_designer/__init__.py +0 -17
  8. data_designer/config/__init__.py +0 -2
  9. data_designer/config/analysis/__init__.py +0 -2
  10. data_designer/config/analysis/column_profilers.py +0 -159
  11. data_designer/config/analysis/column_statistics.py +0 -421
  12. data_designer/config/analysis/dataset_profiler.py +0 -84
  13. data_designer/config/analysis/utils/errors.py +0 -10
  14. data_designer/config/analysis/utils/reporting.py +0 -192
  15. data_designer/config/base.py +0 -69
  16. data_designer/config/column_configs.py +0 -470
  17. data_designer/config/column_types.py +0 -141
  18. data_designer/config/config_builder.py +0 -595
  19. data_designer/config/data_designer_config.py +0 -40
  20. data_designer/config/dataset_builders.py +0 -13
  21. data_designer/config/dataset_metadata.py +0 -18
  22. data_designer/config/default_model_settings.py +0 -121
  23. data_designer/config/errors.py +0 -24
  24. data_designer/config/exports.py +0 -145
  25. data_designer/config/interface.py +0 -55
  26. data_designer/config/models.py +0 -455
  27. data_designer/config/preview_results.py +0 -41
  28. data_designer/config/processors.py +0 -148
  29. data_designer/config/run_config.py +0 -48
  30. data_designer/config/sampler_constraints.py +0 -52
  31. data_designer/config/sampler_params.py +0 -639
  32. data_designer/config/seed.py +0 -116
  33. data_designer/config/seed_source.py +0 -84
  34. data_designer/config/seed_source_types.py +0 -19
  35. data_designer/config/utils/code_lang.py +0 -82
  36. data_designer/config/utils/constants.py +0 -363
  37. data_designer/config/utils/errors.py +0 -21
  38. data_designer/config/utils/info.py +0 -94
  39. data_designer/config/utils/io_helpers.py +0 -258
  40. data_designer/config/utils/misc.py +0 -78
  41. data_designer/config/utils/numerical_helpers.py +0 -30
  42. data_designer/config/utils/type_helpers.py +0 -106
  43. data_designer/config/utils/visualization.py +0 -482
  44. data_designer/config/validator_params.py +0 -94
  45. data_designer/engine/__init__.py +0 -2
  46. data_designer/engine/analysis/column_profilers/base.py +0 -49
  47. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +0 -153
  48. data_designer/engine/analysis/column_profilers/registry.py +0 -22
  49. data_designer/engine/analysis/column_statistics.py +0 -145
  50. data_designer/engine/analysis/dataset_profiler.py +0 -149
  51. data_designer/engine/analysis/errors.py +0 -9
  52. data_designer/engine/analysis/utils/column_statistics_calculations.py +0 -234
  53. data_designer/engine/analysis/utils/judge_score_processing.py +0 -132
  54. data_designer/engine/column_generators/__init__.py +0 -2
  55. data_designer/engine/column_generators/generators/__init__.py +0 -2
  56. data_designer/engine/column_generators/generators/base.py +0 -122
  57. data_designer/engine/column_generators/generators/embedding.py +0 -35
  58. data_designer/engine/column_generators/generators/expression.py +0 -55
  59. data_designer/engine/column_generators/generators/llm_completion.py +0 -113
  60. data_designer/engine/column_generators/generators/samplers.py +0 -69
  61. data_designer/engine/column_generators/generators/seed_dataset.py +0 -144
  62. data_designer/engine/column_generators/generators/validation.py +0 -140
  63. data_designer/engine/column_generators/registry.py +0 -60
  64. data_designer/engine/column_generators/utils/errors.py +0 -15
  65. data_designer/engine/column_generators/utils/generator_classification.py +0 -43
  66. data_designer/engine/column_generators/utils/judge_score_factory.py +0 -58
  67. data_designer/engine/column_generators/utils/prompt_renderer.py +0 -100
  68. data_designer/engine/compiler.py +0 -97
  69. data_designer/engine/configurable_task.py +0 -71
  70. data_designer/engine/dataset_builders/artifact_storage.py +0 -283
  71. data_designer/engine/dataset_builders/column_wise_builder.py +0 -338
  72. data_designer/engine/dataset_builders/errors.py +0 -15
  73. data_designer/engine/dataset_builders/multi_column_configs.py +0 -46
  74. data_designer/engine/dataset_builders/utils/__init__.py +0 -2
  75. data_designer/engine/dataset_builders/utils/concurrency.py +0 -215
  76. data_designer/engine/dataset_builders/utils/config_compiler.py +0 -62
  77. data_designer/engine/dataset_builders/utils/dag.py +0 -62
  78. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +0 -200
  79. data_designer/engine/dataset_builders/utils/errors.py +0 -15
  80. data_designer/engine/errors.py +0 -51
  81. data_designer/engine/model_provider.py +0 -77
  82. data_designer/engine/models/__init__.py +0 -2
  83. data_designer/engine/models/errors.py +0 -300
  84. data_designer/engine/models/facade.py +0 -287
  85. data_designer/engine/models/factory.py +0 -42
  86. data_designer/engine/models/litellm_overrides.py +0 -179
  87. data_designer/engine/models/parsers/__init__.py +0 -2
  88. data_designer/engine/models/parsers/errors.py +0 -34
  89. data_designer/engine/models/parsers/parser.py +0 -235
  90. data_designer/engine/models/parsers/postprocessors.py +0 -93
  91. data_designer/engine/models/parsers/tag_parsers.py +0 -62
  92. data_designer/engine/models/parsers/types.py +0 -84
  93. data_designer/engine/models/recipes/base.py +0 -81
  94. data_designer/engine/models/recipes/response_recipes.py +0 -293
  95. data_designer/engine/models/registry.py +0 -146
  96. data_designer/engine/models/telemetry.py +0 -359
  97. data_designer/engine/models/usage.py +0 -73
  98. data_designer/engine/models/utils.py +0 -38
  99. data_designer/engine/processing/ginja/__init__.py +0 -2
  100. data_designer/engine/processing/ginja/ast.py +0 -65
  101. data_designer/engine/processing/ginja/environment.py +0 -463
  102. data_designer/engine/processing/ginja/exceptions.py +0 -56
  103. data_designer/engine/processing/ginja/record.py +0 -32
  104. data_designer/engine/processing/gsonschema/__init__.py +0 -2
  105. data_designer/engine/processing/gsonschema/exceptions.py +0 -15
  106. data_designer/engine/processing/gsonschema/schema_transformers.py +0 -83
  107. data_designer/engine/processing/gsonschema/types.py +0 -10
  108. data_designer/engine/processing/gsonschema/validators.py +0 -202
  109. data_designer/engine/processing/processors/base.py +0 -13
  110. data_designer/engine/processing/processors/drop_columns.py +0 -42
  111. data_designer/engine/processing/processors/registry.py +0 -25
  112. data_designer/engine/processing/processors/schema_transform.py +0 -49
  113. data_designer/engine/processing/utils.py +0 -169
  114. data_designer/engine/registry/base.py +0 -99
  115. data_designer/engine/registry/data_designer_registry.py +0 -39
  116. data_designer/engine/registry/errors.py +0 -12
  117. data_designer/engine/resources/managed_dataset_generator.py +0 -39
  118. data_designer/engine/resources/managed_dataset_repository.py +0 -197
  119. data_designer/engine/resources/managed_storage.py +0 -65
  120. data_designer/engine/resources/resource_provider.py +0 -77
  121. data_designer/engine/resources/seed_reader.py +0 -154
  122. data_designer/engine/sampling_gen/column.py +0 -91
  123. data_designer/engine/sampling_gen/constraints.py +0 -100
  124. data_designer/engine/sampling_gen/data_sources/base.py +0 -217
  125. data_designer/engine/sampling_gen/data_sources/errors.py +0 -12
  126. data_designer/engine/sampling_gen/data_sources/sources.py +0 -347
  127. data_designer/engine/sampling_gen/entities/__init__.py +0 -2
  128. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  129. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +0 -86
  130. data_designer/engine/sampling_gen/entities/email_address_utils.py +0 -171
  131. data_designer/engine/sampling_gen/entities/errors.py +0 -10
  132. data_designer/engine/sampling_gen/entities/national_id_utils.py +0 -102
  133. data_designer/engine/sampling_gen/entities/person.py +0 -144
  134. data_designer/engine/sampling_gen/entities/phone_number.py +0 -128
  135. data_designer/engine/sampling_gen/errors.py +0 -26
  136. data_designer/engine/sampling_gen/generator.py +0 -122
  137. data_designer/engine/sampling_gen/jinja_utils.py +0 -64
  138. data_designer/engine/sampling_gen/people_gen.py +0 -199
  139. data_designer/engine/sampling_gen/person_constants.py +0 -56
  140. data_designer/engine/sampling_gen/schema.py +0 -147
  141. data_designer/engine/sampling_gen/schema_builder.py +0 -61
  142. data_designer/engine/sampling_gen/utils.py +0 -46
  143. data_designer/engine/secret_resolver.py +0 -82
  144. data_designer/engine/validation.py +0 -367
  145. data_designer/engine/validators/__init__.py +0 -19
  146. data_designer/engine/validators/base.py +0 -38
  147. data_designer/engine/validators/local_callable.py +0 -39
  148. data_designer/engine/validators/python.py +0 -254
  149. data_designer/engine/validators/remote.py +0 -89
  150. data_designer/engine/validators/sql.py +0 -65
  151. data_designer/errors.py +0 -7
  152. data_designer/essentials/__init__.py +0 -33
  153. data_designer/lazy_heavy_imports.py +0 -54
  154. data_designer/logging.py +0 -163
  155. data_designer/plugin_manager.py +0 -78
  156. data_designer/plugins/__init__.py +0 -8
  157. data_designer/plugins/errors.py +0 -15
  158. data_designer/plugins/plugin.py +0 -141
  159. data_designer/plugins/registry.py +0 -88
  160. data_designer/plugins/testing/__init__.py +0 -10
  161. data_designer/plugins/testing/stubs.py +0 -116
  162. data_designer/plugins/testing/utils.py +0 -20
  163. data_designer-0.3.8rc1.dist-info/RECORD +0 -196
  164. data_designer-0.3.8rc1.dist-info/licenses/LICENSE +0 -201
  165. {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/WHEEL +0 -0
  166. {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/entry_points.txt +0 -0
data_designer/engine/analysis/utils/column_statistics_calculations.py
@@ -1,234 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- import logging
- from numbers import Number
- from typing import TYPE_CHECKING, Any
-
- import tiktoken
-
- from data_designer.config.analysis.column_statistics import (
-     CategoricalDistribution,
-     ColumnDistributionType,
-     MissingValue,
-     NumericalDistribution,
- )
- from data_designer.config.column_configs import (
-     LLMTextColumnConfig,
- )
- from data_designer.engine.column_generators.utils.prompt_renderer import (
-     PromptType,
-     RecordBasedPromptRenderer,
-     create_response_recipe,
- )
- from data_designer.lazy_heavy_imports import np, pa, pd
-
- if TYPE_CHECKING:
-     import numpy as np
-     import pandas as pd
-     import pyarrow as pa
-
- RANDOM_SEED = 42
- MAX_PROMPT_SAMPLE_SIZE = 1000
- TOKENIZER = tiktoken.get_encoding("cl100k_base")
- WARNING_PREFIX = "⚠️ Error during column profile calculation: "
- TEXT_FIELD_AVG_SPACE_COUNT_THRESHOLD = 0.1
-
- logger = logging.getLogger(__name__)
-
-
- def calculate_column_distribution(
-     column_name: str, df: pd.DataFrame, distribution_type: ColumnDistributionType
- ) -> dict[str, CategoricalDistribution | NumericalDistribution | MissingValue | None]:
-     distribution_type = ColumnDistributionType(distribution_type)
-     try:
-         if distribution_type == ColumnDistributionType.CATEGORICAL:
-             return {
-                 "distribution_type": ColumnDistributionType.CATEGORICAL,
-                 "distribution": CategoricalDistribution.from_series(df[column_name]),
-             }
-
-         if distribution_type == ColumnDistributionType.NUMERICAL:
-             return {
-                 "distribution_type": ColumnDistributionType.NUMERICAL,
-                 "distribution": NumericalDistribution.from_series(df[column_name]),
-             }
-     except Exception as e:
-         logger.warning(f"{WARNING_PREFIX} failed to calculate column distribution for '{column_name}' {e}")
-     return {
-         "distribution_type": ColumnDistributionType.UNKNOWN,
-         "distribution": MissingValue.CALCULATION_FAILED,
-     }
-
-
- def calculate_general_column_info(column_name: str, df: pd.DataFrame) -> dict[str, Any]:
-     try:
-         _df = pd.DataFrame(df[column_name].apply(ensure_hashable))
-
-         if has_pyarrow_backend(df):
-             pyarrow_dtype = str(df[column_name].dtype.pyarrow_dtype)
-             simple_dtype = convert_pyarrow_dtype_to_simple_dtype(df[column_name].dtype.pyarrow_dtype)
-         else:
-             # We do not log a warning at the column-level because it would be too noisy.
-             # However, there is a logged warning at the dataset-profiler level.
-             try:
-                 simple_dtype = get_column_data_type_from_first_non_null_value(column_name, df)
-             except Exception:
-                 simple_dtype = MissingValue.CALCULATION_FAILED
-             pyarrow_dtype = "n/a"
-
-         return {
-             "pyarrow_dtype": pyarrow_dtype,
-             "simple_dtype": simple_dtype,
-             "num_records": len(_df[column_name]),
-             "num_null": _df[column_name].isnull().sum(),
-             "num_unique": _df[column_name].nunique(),
-         }
-     except Exception as e:
-         logger.warning(f"{WARNING_PREFIX} failed to calculate general column info for '{column_name}': {e}")
-         return {
-             "pyarrow_dtype": MissingValue.CALCULATION_FAILED,
-             "simple_dtype": MissingValue.CALCULATION_FAILED,
-             "num_records": MissingValue.CALCULATION_FAILED,
-             "num_null": MissingValue.CALCULATION_FAILED,
-             "num_unique": MissingValue.CALCULATION_FAILED,
-         }
-
-
- def calculate_input_token_stats(
-     column_config: LLMTextColumnConfig, df: pd.DataFrame
- ) -> dict[str, float | MissingValue]:
-     try:
-         num_tokens = []
-         num_samples = min(MAX_PROMPT_SAMPLE_SIZE, len(df))
-         renderer = RecordBasedPromptRenderer(response_recipe=create_response_recipe(column_config))
-         for record in df.sample(num_samples, random_state=RANDOM_SEED).to_dict(orient="records"):
-             system_prompt = renderer.render(
-                 prompt_template=column_config.system_prompt, record=record, prompt_type=PromptType.SYSTEM_PROMPT
-             )
-             prompt = renderer.render(
-                 prompt_template=column_config.prompt, record=record, prompt_type=PromptType.USER_PROMPT
-             )
-             concatenated_prompt = str(system_prompt + "\n\n" + prompt)
-             num_tokens.append(len(TOKENIZER.encode(concatenated_prompt, disallowed_special=())))
-     except Exception as e:
-         logger.warning(f"{WARNING_PREFIX} failed to calculate input token stats for column {column_config.name!r}: {e}")
-         return {
-             "input_tokens_mean": MissingValue.CALCULATION_FAILED,
-             "input_tokens_median": MissingValue.CALCULATION_FAILED,
-             "input_tokens_stddev": MissingValue.CALCULATION_FAILED,
-         }
-     return {
-         "input_tokens_mean": np.mean(num_tokens),
-         "input_tokens_median": np.median(num_tokens),
-         "input_tokens_stddev": np.std(num_tokens),
-     }
-
-
- def calculate_output_token_stats(
-     column_config: LLMTextColumnConfig, df: pd.DataFrame
- ) -> dict[str, float | MissingValue]:
-     try:
-         tokens_per_record = df[column_config.name].apply(
-             lambda value: len(TOKENIZER.encode(str(value), disallowed_special=()))
-         )
-         return {
-             "output_tokens_mean": tokens_per_record.mean(),
-             "output_tokens_median": tokens_per_record.median(),
-             "output_tokens_stddev": tokens_per_record.std(),
-         }
-     except Exception as e:
-         logger.warning(f"{WARNING_PREFIX} failed to calculate output token stats for column {column_config.name}: {e}")
-         return {
-             "output_tokens_mean": MissingValue.CALCULATION_FAILED,
-             "output_tokens_median": MissingValue.CALCULATION_FAILED,
-             "output_tokens_stddev": MissingValue.CALCULATION_FAILED,
-         }
-
-
- def calculate_token_stats(column_config: LLMTextColumnConfig, df: pd.DataFrame) -> dict[str, float | MissingValue]:
-     return {
-         **calculate_input_token_stats(column_config, df),
-         **calculate_output_token_stats(column_config, df),
-     }
-
-
- def calculate_validation_column_info(column_name: str, df: pd.DataFrame) -> dict[str, Any]:
-     try:
-         return {"num_valid_records": df[column_name].apply(lambda x: ensure_boolean(x["is_valid"])).sum()}
-     except Exception as e:
-         logger.warning(
-             f"{WARNING_PREFIX} failed to calculate code validation column info for column {column_name}: {e}"
-         )
-         return {"num_valid_records": MissingValue.CALCULATION_FAILED}
-
-
- def convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype: pa.DataType) -> str:
-     if isinstance(pyarrow_dtype, pa.ListType):
-         return f"list[{convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype.value_type)}]"
-     if isinstance(pyarrow_dtype, pa.StructType):
-         return "dict"
-     return convert_to_simple_dtype(str(pyarrow_dtype))
-
-
- def convert_to_simple_dtype(dtype: str) -> str:
-     if "int" in dtype:
-         return "int"
-     if "double" in dtype:
-         return "float"
-     if "float" in dtype:
-         return "float"
-     if "str" in dtype:
-         return "string"
-     if "timestamp" in dtype:
-         return "timestamp"
-     if "time" in dtype:
-         return "time"
-     if "date" in dtype:
-         return "date"
-     return dtype
-
-
- def get_column_data_type_from_first_non_null_value(column_name: str, df: pd.DataFrame) -> str:
-     df_no_nulls = df[column_name].dropna()
-     if len(df_no_nulls) == 0:
-         return MissingValue.CALCULATION_FAILED
-     dtype = type(df_no_nulls.iloc[0]).__name__
-     return convert_to_simple_dtype(dtype)
-
-
- def ensure_hashable(x: Any) -> str:
-     """
-     Makes a best effort turn known unhashable types to a hashable
-     string representation that preserves both structure and values.
-     """
-     if isinstance(x, (Number, bool)) or x is None:
-         return x
-
-     if isinstance(x, dict):
-         # Sort by keys and convert key-value pairs to tuples
-         return str(sorted([(str(k), ensure_hashable(v)) for k, v in x.items()]))
-
-     if isinstance(x, (list, tuple, set, np.ndarray)):
-         # Recursively make all elements hashable
-         return str(sorted([ensure_hashable(e) for e in x]))
-
-     return str(x)
-
-
- def ensure_boolean(v: bool | str | int | None) -> bool:
-     if isinstance(v, (bool, np.bool_)):
-         return bool(v)
-     if isinstance(v, (int, float, np.integer, np.floating)) and v in [0, 1, 0.0, 1.0]:
-         return bool(v)
-     if isinstance(v, (str, np.str_)) and v.lower() in ["true", "false"]:
-         return v.lower() == "true"
-     if v is None:
-         return False
-     raise ValueError(f"Invalid boolean value: {v}")
-
-
- def has_pyarrow_backend(df: pd.DataFrame) -> bool:
-     return all(isinstance(dtype, pd.ArrowDtype) for dtype in df.dtypes)
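
For reference, the removed token-stats helpers boil down to encoding each rendered prompt with tiktoken's cl100k_base encoding and summarizing the counts with numpy. A minimal standalone sketch of that technique (the prompts below are placeholder data, not package output):

import numpy as np
import tiktoken

# Stand-ins for rendered system + user prompts (one per sampled record).
prompts = ["Summarize this record.", "Classify the sentiment of: great product!"]

tokenizer = tiktoken.get_encoding("cl100k_base")
# disallowed_special=() matches the removed code: treat special tokens as plain text.
num_tokens = [len(tokenizer.encode(p, disallowed_special=())) for p in prompts]

stats = {
    "input_tokens_mean": np.mean(num_tokens),
    "input_tokens_median": np.median(num_tokens),
    "input_tokens_stddev": np.std(num_tokens),
}
print(stats)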
data_designer/engine/analysis/utils/judge_score_processing.py
@@ -1,132 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- import logging
- from collections import defaultdict
- from typing import TYPE_CHECKING, Any
-
- from data_designer.config.analysis.column_profilers import JudgeScoreDistributions, JudgeScoreSample
- from data_designer.config.analysis.column_statistics import (
-     CategoricalDistribution,
-     ColumnDistributionType,
-     MissingValue,
-     NumericalDistribution,
- )
- from data_designer.config.column_configs import LLMJudgeColumnConfig
- from data_designer.lazy_heavy_imports import pd
-
- if TYPE_CHECKING:
-     import pandas as pd
-
- logger = logging.getLogger(__name__)
-
-
- def extract_judge_score_distributions(
-     column_config: LLMJudgeColumnConfig, df: pd.DataFrame
- ) -> JudgeScoreDistributions | MissingValue:
-     scores = defaultdict(list)
-     reasoning = defaultdict(list)
-
-     # Aggregate results as dicts of form {score_name: <result>}.
-     histograms = {}
-     distributions = {}
-     distribution_types = {}
-
-     for score in column_config.scores:
-         is_numerical = True
-         name = score.name
-         for results in df[column_config.name]:
-             try:
-                 score = results[name].get("score", None)
-
-                 if _can_be_converted_to_int(score):
-                     score = int(score)
-                 else:
-                     score = str(score)
-                     is_numerical = False
-
-                 scores[name].append(score)
-                 reasoning[name].append(results[name].get("reasoning", "No reasoning provided"))
-             except Exception as e:
-                 logger.warning(f"⚠️ Failed to extract judge score for '{name}': {e}")
-                 return MissingValue.OUTPUT_FORMAT_ERROR
-
-         try:
-             series = pd.Series(scores[name], name=name)
-             cat_dist = CategoricalDistribution.from_series(series)
-
-             # For judge scores, build a categorical histogram, since numerical scores are integers.
-             histograms[name] = cat_dist.histogram
-
-             if is_numerical:
-                 distribution_types[name] = ColumnDistributionType.NUMERICAL
-                 distributions[name] = NumericalDistribution.from_series(series)
-             else:
-                 distribution_types[name] = ColumnDistributionType.CATEGORICAL
-                 distributions[name] = cat_dist
-
-         except Exception as e:
-             logger.warning(f"⚠️ Failed to calculate judge score distribution for '{name}': {e}")
-             distribution_types[name] = ColumnDistributionType.UNKNOWN
-             distributions[name] = MissingValue.CALCULATION_FAILED
-             histograms[name] = MissingValue.CALCULATION_FAILED
-
-     return JudgeScoreDistributions(
-         scores=dict(scores),
-         reasoning=dict(reasoning),
-         distribution_types=distribution_types,
-         distributions=distributions,
-         histograms=histograms,
-     )
-
-
- def sample_scores_and_reasoning(
-     scores: list[int | str],
-     reasoning: list[str],
-     num_samples: int,
-     random_seed: int | None = None,
- ) -> list[JudgeScoreSample]:
-     if len(scores) != len(reasoning):
-         raise ValueError("scores and reasoning must have the same length")
-
-     if len(scores) == 0:
-         raise ValueError("scores and reasoning must not be empty")
-
-     if num_samples <= 0:
-         raise ValueError("num_samples must be greater than 0")
-
-     df_samples = pd.DataFrame({"score": scores, "reasoning": reasoning})
-
-     if len(scores) <= num_samples:
-         return [JudgeScoreSample(score=score, reasoning=reasoning) for score, reasoning in zip(scores, reasoning)]
-
-     # Sample maintaining original proportions from each category (int or str)
-     # Calculate the frequency of each score category
-     score_category_counts = df_samples["score"].value_counts()
-
-     # If more categories than samples, pick one sample from each of the most frequent categories
-     if len(score_category_counts) >= num_samples:
-         top_categories = score_category_counts.head(num_samples).index
-         samples = pd.concat(
-             [df_samples[df_samples["score"] == cat].sample(n=1, random_state=random_seed) for cat in top_categories],
-             ignore_index=True,
-         )
-     else:
-         # Sample proportionally to maintain original category ratios
-         # Create weights based on the original frequency of each score
-         weights = df_samples["score"].map(score_category_counts)
-         samples = df_samples.sample(n=num_samples, weights=weights, random_state=random_seed)
-
-     return [
-         JudgeScoreSample(score=row["score"], reasoning=row["reasoning"]) for row in samples.to_dict(orient="records")
-     ]
-
-
- def _can_be_converted_to_int(value: Any) -> bool:
-     try:
-         int(value)
-         return True
-     except (ValueError, TypeError):
-         return False
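
The proportional branch of sample_scores_and_reasoning above reduces to a frequency-weighted pandas sample: each row is weighted by how often its score occurs, so the sample roughly preserves the original category mix. A standalone sketch with made-up scores:

import pandas as pd

scores = [5, 5, 5, 4, 4, 3, 5, 4, 5, 3]
reasoning = [f"reasoning {i}" for i in range(len(scores))]
df_samples = pd.DataFrame({"score": scores, "reasoning": reasoning})

# Map each row's score to its overall frequency and use that as the sampling weight.
weights = df_samples["score"].map(df_samples["score"].value_counts())
samples = df_samples.sample(n=4, weights=weights, random_state=42)
print(samples)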
data_designer/engine/column_generators/__init__.py
@@ -1,2 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
data_designer/engine/column_generators/generators/__init__.py
@@ -1,2 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
data_designer/engine/column_generators/generators/base.py
@@ -1,122 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- import functools
- import logging
- from abc import ABC, abstractmethod
- from enum import Enum
- from typing import TYPE_CHECKING, overload
-
- from data_designer.engine.configurable_task import ConfigurableTask, DataT, TaskConfigT
- from data_designer.lazy_heavy_imports import pd
-
- if TYPE_CHECKING:
-     import pandas as pd
-
-     from data_designer.config.models import BaseInferenceParams, ModelConfig
-     from data_designer.engine.models.facade import ModelFacade
-     from data_designer.engine.models.registry import ModelRegistry
-
- logger = logging.getLogger(__name__)
-
-
- class GenerationStrategy(str, Enum):
-     CELL_BY_CELL = "cell_by_cell"
-     FULL_COLUMN = "full_column"
-
-
- class ColumnGenerator(ConfigurableTask[TaskConfigT], ABC):
-     @property
-     def can_generate_from_scratch(self) -> bool:
-         return False
-
-     @staticmethod
-     @abstractmethod
-     def get_generation_strategy() -> GenerationStrategy: ...
-
-     @overload
-     @abstractmethod
-     def generate(self, data: dict) -> dict: ...
-
-     @overload
-     @abstractmethod
-     def generate(self, data: pd.DataFrame) -> pd.DataFrame: ...
-
-     @abstractmethod
-     def generate(self, data: DataT) -> DataT: ...
-
-     def log_pre_generation(self) -> None:
-         """A shared method to log info before the generator's `generate` method is called.
-
-         The idea is for dataset builders to call this method for all generators before calling their
-         `generate` method. This is to avoid logging the same information multiple times when running
-         generators in parallel.
-         """
-
-
- class FromScratchColumnGenerator(ColumnGenerator[TaskConfigT], ABC):
-     @property
-     def can_generate_from_scratch(self) -> bool:
-         return True
-
-     @abstractmethod
-     def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ...
-
-
- class ColumnGeneratorWithModelRegistry(ColumnGenerator[TaskConfigT], ABC):
-     @property
-     def model_registry(self) -> ModelRegistry:
-         return self.resource_provider.model_registry
-
-     def get_model(self, model_alias: str) -> ModelFacade:
-         return self.model_registry.get_model(model_alias=model_alias)
-
-     def get_model_config(self, model_alias: str) -> ModelConfig:
-         return self.model_registry.get_model_config(model_alias=model_alias)
-
-     def get_model_provider_name(self, model_alias: str) -> str:
-         provider = self.model_registry.get_model_provider(model_alias=model_alias)
-         return provider.name
-
-
- class ColumnGeneratorWithModel(ColumnGeneratorWithModelRegistry[TaskConfigT], ABC):
-     @functools.cached_property
-     def model(self) -> ModelFacade:
-         return self.get_model(model_alias=self.config.model_alias)
-
-     @functools.cached_property
-     def model_config(self) -> ModelConfig:
-         return self.get_model_config(model_alias=self.config.model_alias)
-
-     @functools.cached_property
-     def inference_parameters(self) -> BaseInferenceParams:
-         return self.model_config.inference_parameters
-
-     def log_pre_generation(self) -> None:
-         logger.info(
-             f"{self.config.get_column_emoji()} {self.config.column_type} model config for column '{self.config.name}'"
-         )
-         logger.info(f" |-- model: {self.model_config.model!r}")
-         logger.info(f" |-- model alias: {self.config.model_alias!r}")
-         logger.info(f" |-- model provider: {self.get_model_provider_name(model_alias=self.config.model_alias)!r}")
-         logger.info(f" |-- inference parameters: {self.inference_parameters.format_for_display()}")
-
-
- class ColumnGeneratorCellByCell(ColumnGenerator[TaskConfigT], ABC):
-     @staticmethod
-     def get_generation_strategy() -> GenerationStrategy:
-         return GenerationStrategy.CELL_BY_CELL
-
-     @abstractmethod
-     def generate(self, data: dict) -> dict: ...
-
-
- class ColumnGeneratorFullColumn(ColumnGenerator[TaskConfigT], ABC):
-     @staticmethod
-     def get_generation_strategy() -> GenerationStrategy:
-         return GenerationStrategy.FULL_COLUMN
-
-     @abstractmethod
-     def generate(self, data: pd.DataFrame) -> pd.DataFrame: ...
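
The removed base classes encode a two-way dispatch: CELL_BY_CELL generators transform one record dict at a time, while FULL_COLUMN generators transform a whole DataFrame in one call. A toy, self-contained mimic of that dispatch (UppercaseGenerator is invented for illustration; the real subclasses also carry configs, a model registry, and a resource provider):

from enum import Enum

import pandas as pd

class GenerationStrategy(str, Enum):
    CELL_BY_CELL = "cell_by_cell"
    FULL_COLUMN = "full_column"

class UppercaseGenerator:
    # Toy stand-in for a ColumnGeneratorFullColumn subclass.
    @staticmethod
    def get_generation_strategy() -> GenerationStrategy:
        return GenerationStrategy.FULL_COLUMN

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        data["shout"] = data["text"].str.upper()
        return data

gen = UppercaseGenerator()
df = pd.DataFrame({"text": ["hello", "world"]})
if gen.get_generation_strategy() is GenerationStrategy.FULL_COLUMN:
    df = gen.generate(df)  # whole-DataFrame path
else:
    # per-record path: a CELL_BY_CELL generator would map dict -> dict
    df = pd.DataFrame(gen.generate(rec) for rec in df.to_dict("records"))
print(df)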
data_designer/engine/column_generators/generators/embedding.py
@@ -1,35 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- from pydantic import BaseModel, computed_field
-
- from data_designer.config.column_configs import EmbeddingColumnConfig
- from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModel, GenerationStrategy
- from data_designer.engine.processing.utils import deserialize_json_values, parse_list_string
-
-
- class EmbeddingGenerationResult(BaseModel):
-     embeddings: list[list[float]]
-
-     @computed_field
-     def num_embeddings(self) -> int:
-         return len(self.embeddings)
-
-     @computed_field
-     def dimension(self) -> int:
-         return len(self.embeddings[0]) if len(self.embeddings) > 0 else 0
-
-
- class EmbeddingCellGenerator(ColumnGeneratorWithModel[EmbeddingColumnConfig]):
-     @staticmethod
-     def get_generation_strategy() -> GenerationStrategy:
-         return GenerationStrategy.CELL_BY_CELL
-
-     def generate(self, data: dict) -> dict:
-         deserialized_record = deserialize_json_values(data)
-         input_texts = parse_list_string(deserialized_record[self.config.target_column])
-         embeddings = self.model.generate_text_embeddings(input_texts=input_texts)
-         data[self.config.name] = EmbeddingGenerationResult(embeddings=embeddings).model_dump(mode="json")
-         return data
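
EmbeddingGenerationResult leans on pydantic v2's computed_field, so num_embeddings and dimension are derived on access and included at dump time rather than stored. A standalone copy of the model showing the serialized shape:

from pydantic import BaseModel, computed_field

class EmbeddingGenerationResult(BaseModel):
    embeddings: list[list[float]]

    @computed_field
    def num_embeddings(self) -> int:
        return len(self.embeddings)

    @computed_field
    def dimension(self) -> int:
        # Assumes all embeddings share the first row's dimension.
        return len(self.embeddings[0]) if len(self.embeddings) > 0 else 0

result = EmbeddingGenerationResult(embeddings=[[0.1, 0.2], [0.3, 0.4]])
print(result.model_dump(mode="json"))
# {'embeddings': [[0.1, 0.2], [0.3, 0.4]], 'num_embeddings': 2, 'dimension': 2}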
data_designer/engine/column_generators/generators/expression.py
@@ -1,55 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- import logging
- from typing import TYPE_CHECKING
-
- from data_designer.config.column_configs import ExpressionColumnConfig
- from data_designer.engine.column_generators.generators.base import ColumnGeneratorFullColumn
- from data_designer.engine.column_generators.utils.errors import ExpressionTemplateRenderError
- from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering
- from data_designer.engine.processing.utils import deserialize_json_values
- from data_designer.lazy_heavy_imports import pd
-
- if TYPE_CHECKING:
-     import pandas as pd
-
- logger = logging.getLogger(__name__)
-
-
- class ExpressionColumnGenerator(WithJinja2UserTemplateRendering, ColumnGeneratorFullColumn[ExpressionColumnConfig]):
-     def generate(self, data: pd.DataFrame) -> pd.DataFrame:
-         logger.info(f"🧩 Generating column `{self.config.name}` from expression")
-
-         missing_columns = list(set(self.config.required_columns) - set(data.columns))
-         if len(missing_columns) > 0:
-             error_msg = (
-                 f"There was an error preparing the Jinja2 expression template. "
-                 f"The following columns {missing_columns} are missing!"
-             )
-             raise ExpressionTemplateRenderError(error_msg)
-
-         self.prepare_jinja2_template_renderer(self.config.expr, data.columns.to_list())
-         records = []
-         for record in data.to_dict(orient="records"):
-             record[self.config.name] = self._cast_type(self.render_template(deserialize_json_values(record)))
-             records.append(record)
-
-         return pd.DataFrame(records)
-
-     def _cast_type(self, value: str) -> str | float | int | bool:
-         if self.config.dtype == "str":
-             return value
-         elif self.config.dtype == "float":
-             return float(value)
-         elif self.config.dtype == "int":
-             return int(float(value))
-         elif self.config.dtype == "bool":
-             try:
-                 return bool(int(float(value)))
-             except ValueError:
-                 return bool(f"{value}".lower() == "true")
-         else:
-             raise ValueError(f"Invalid dtype: {self.config.dtype}")
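
The _cast_type fallbacks above are easy to miss: int casts go through float first (so "3.0" parses cleanly), and bool falls back to string matching when numeric parsing fails. A standalone copy (renamed cast_expression_result here, since the method is removed in 0.4.0) with spot checks:

def cast_expression_result(value: str, dtype: str):
    # Standalone copy of ExpressionColumnGenerator._cast_type for illustration.
    if dtype == "str":
        return value
    elif dtype == "float":
        return float(value)
    elif dtype == "int":
        return int(float(value))  # "3.0" -> 3 instead of a ValueError
    elif dtype == "bool":
        try:
            return bool(int(float(value)))  # "1" / "0" / "1.0" -> True / False
        except ValueError:
            return f"{value}".lower() == "true"  # "True" / "false" strings
    else:
        raise ValueError(f"Invalid dtype: {dtype}")

assert cast_expression_result("3.0", "int") == 3
assert cast_expression_result("1", "bool") is True
assert cast_expression_result("false", "bool") is False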